Bug Report: DTensor/torch.Tensor Mixed Type Error in Llama4 LoRA Fine-tuning #2856

@ronyadgar

Description

I followed this cookbook to fine-tune Llama 4: https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/finetune_llama4.md

Bug Description

Running distributed LoRA fine-tuning on the Llama4-Scout-17B-16E model fails with RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

Error Message

RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
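For context, this error comes from DTensor's operator dispatch: it is raised whenever a distributed op such as aten.mm receives one sharded (DTensor) operand and one plain torch.Tensor operand. A minimal sketch of the same failure mode, written here for illustration and not taken from the torchtune recipe:

# minimal_dtensor_mix.py -- illustrative repro of the error class only.
# Run with: torchrun --nproc_per_node 2 minimal_dtensor_mix.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Shard a weight across ranks, roughly what FSDP2 does to model parameters.
weight = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

x = torch.randn(4, 16)   # plain torch.Tensor, never converted to DTensor
y = torch.mm(x, weight)  # RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor ...

In the LoRA recipe, hitting this would be consistent with some trainable weight (or an activation derived from it) remaining a plain torch.Tensor while the frozen base weights were sharded into DTensors.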

To Reproduce

Command:
tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora
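If the error fires on the first training step (as a mixed-operand mm in the forward pass typically would), a shorter run reproduces it just as well. This override is a suggestion, not from the original report; it follows the same command-line override pattern shown in the config comments below:

tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora max_steps_per_epoch=1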

Configuration (scout_17B_16E_lora.yaml)

# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a Llama4 17Bx16E MoE model
#
# This config assumes that you've run the following command before launching:
#   tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
#
# To launch on 8 devices, run the following command from root:
#   tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora
#
# You can add specific overrides through the command line. For example, to use a larger bsz:
#   tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora batch_size=8
#
# This config was only tested on an 8xA100 machine.

output_dir: /tmp/torchtune/llama4_17Bx16E/lora

# Modeling Arguments
model:
  _component_: torchtune.models.llama4.lora_llama4_scout_17b_16e
  decoder_trainable: "lora"
  encoder_trainable: "frozen"
  fusion_trainable: "lora"
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 16  # higher increases accuracy and memory
  lora_alpha: 32  # usually alpha=2*rank
  lora_dropout: 0.0

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4
save_adapter_weights_only: True # use this for faster checkpoint save
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False  # True increases speed
  split: train[:95%]
seed: null
shuffle: True

# Validation
run_val_every_n_steps: null  # Change to an integer to enable validation every N steps
dataset_val:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  split: train[95%:]
batch_size_val: ${batch_size}

# Training arguments
epochs: 1
batch_size: 2
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: False
optimizer_in_bwd: False
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
clip_grad_norm: null

# cuda, cpu, rocm, xpu...
device: cuda

# Memory management / performance
enable_activation_checkpointing: True
enable_activation_offloading: False
custom_sharded_layers: ['tok_embeddings']
fsdp_cpu_offload: False
compile: False # torch.compile, set to true for perf/memory improvement

# Reduced precision
dtype: bf16

# Log metrics during training
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
log_level: INFO  # DEBUG, WARN, etc.

# Useful for understanding how to optimize memory and performance
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
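
As a debugging aid (not part of the original report), one could check, after the recipe shards the model, which parameters remain plain torch.Tensors; any trainable LoRA weight listed here would be a candidate source of the mixed-operand mm. A hypothetical helper:

# Hypothetical debugging helper -- the function name and print format are
# illustrative; only torch.distributed.tensor.DTensor is a real PyTorch API.
from torch.distributed.tensor import DTensor

def report_unsharded_params(model):
    # FSDP2 replaces sharded parameters with DTensors; anything that is
    # still a plain torch.Tensor here never got sharded.
    for name, param in model.named_parameters():
        if not isinstance(param, DTensor):
            print(f"plain torch.Tensor: {name} "
                  f"(requires_grad={param.requires_grad}, shape={tuple(param.shape)})")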

Environment Information

PyTorch: 2.8.0.dev20250625+cu126
CUDA: 12.6
CUDA devices: 8
Python: 3.10.18
OS: Linux ip-10-207-79-120 6.8.0-1030-aws #32-Ubuntu SMP Wed May 28 19:48:56 UTC 2025 x86_64 x86_64 x86_64 GNU/Linux

(py310) ubuntu@ip-10-207-79-120:~/torchtune/recipes/configs/llama4$ nvidia-smi
Sun Jun 29 10:56:45 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.20             Driver Version: 570.133.20     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  |   00000000:10:1C.0 Off |                    0 |
| N/A   49C    P0             69W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  |   00000000:10:1D.0 Off |                    0 |
| N/A   46C    P0             68W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-80GB          On  |   00000000:20:1C.0 Off |                    0 |
| N/A   49C    P0             72W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-80GB          On  |   00000000:20:1D.0 Off |                    0 |
| N/A   43C    P0             65W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A100-SXM4-80GB          On  |   00000000:90:1C.0 Off |                    0 |
| N/A   47C    P0             66W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A100-SXM4-80GB          On  |   00000000:90:1D.0 Off |                    0 |
| N/A   44C    P0             70W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A100-SXM4-80GB          On  |   00000000:A0:1C.0 Off |                    0 |
| N/A   45C    P0             69W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A100-SXM4-80GB          On  |   00000000:A0:1D.0 Off |                    0 |
| N/A   42C    P0             68W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
