I'm opening this issue after first asking the question in Discussions, where @innat suggested that it would be better to open a ticket here.
I'm trying to fine-tune Meta's LLaMA 3.2 3B Instruct model using Keras on TPU v3-8 (Kaggle). While the code runs without errors, the loss remains constant during training.
Context:
- My wife, a history and geography teacher, writes short evaluations for her students every trimester. Each evaluation comes with three indicators: attitude in class (0-10), personal work (0-10), and participation in class (0-10), along with the trimester (1, 2, or 3) and the mean score (0-20). I wrapped her evaluations in an HF dataset (in French).
- I successfully fine-tuned the same model on the same dataset using HuggingFace's SFTTrainer and Unsloth GPU optimizations (working model here), and the notebook used for that fine-tuning is available as well.
- The model works well, and a demo runs on HF Spaces at this link. It is very slow on CPU, but if you have access to a ZeroGPU, just duplicate the Space and it is really fast.
- The current notebook, designed to run on Kaggle with a Google TPU v3-8, is available here.
Technical Setup:
- TPU v3-8 on Kaggle
- tensorflow==2.16.2
- keras==3.0.5
- Base model: meta-llama/Llama-3.2-3B-Instruct
Here's the relevant code, copied from the full notebook:
import keras
import keras_hub

# TPU setup: one "batch" slice, model weights sharded across the 8 TPU cores
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=devices,
)
layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)  # default layout_map
distrib = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(distrib)

# Model initialization
llama_model = keras_hub.models.Llama3CausalLM.from_preset("hf://meta-llama/Llama-3.2-3B-Instruct")
llama_model.backbone.enable_lora(rank=8)
llama_model.preprocessor.sequence_length = 256

optimizer = keras.optimizers.AdamW(
    learning_rate=2e-4,
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# dataset adaptation
# sample dataset content
# multi_turn_dataset is an instance of HF datasets.DatasetDict
# multi_turn_dataset['train'][42]['text']
# "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 July 2024\n\nVous êtes une IA assistant les enseignants d'histoire-géographie en rédigeant à leur place une appréciation personnalisée pour leur élève en fonction de ses performances. Votre appréciation doit être en français formel et impersonnel. Votre appréciation doit être bienveillante, constructive, et aider l'élève à comprendre ses points forts et les axes d'amélioration. Votre appréciation doit comporter de 8 à 250 caractères. Votre appréciation ne doit jamais comporter les valeurs des notes. <|eot_id|><|start_header_id|>user<|end_header_id|>\n\nVeuillez rédiger une appréciation en moins de 250 caractères pour le premier trimestre pour cet élève qui a eu 14.0 de moyenne, j'ai évalué son comportement à 4.8/10, sa participation à 3.1/10 et son travail à 3.9/10. Les notes ne doivent pas apparaître dans l'appréciation.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nL'ensemble est correct mais vous pourrez certainement progresser au second trimestre en faisant davantage d'efforts de participation et dans le travail personnel. Attention également à ne pas se déconcentrer en cours.<|eot_id|>"

llama_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
    # jit_compile="auto",
    # auto_scale_loss=True,
)

# Training
llama_model.fit(
    multi_turn_dataset['train'].with_format("tf")['text'],
    validation_data=multi_turn_dataset['validation'].with_format("tf")['text'],
    epochs=3,
    verbose="auto",
)
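
For readers who want to reproduce the input format without the dataset, the sample `text` record shown in the comments above follows the Llama 3 chat layout. The helper below is only a sketch of how such a string could be assembled; `build_chat_text` and the shortened message contents are hypothetical and are not taken from the notebook.

```python
# Hypothetical helper (not from the notebook): assembles one training string
# in the Llama 3 chat layout visible in the sample record.
def build_chat_text(system: str, user: str, assistant: str) -> str:
    def turn(role: str, content: str) -> str:
        # Each turn is a role header, a blank line, the content, then <|eot_id|>.
        return f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"

    return (
        "<|begin_of_text|>"
        + turn("system", system)
        + turn("user", user)
        + turn("assistant", assistant)
    )


# Shortened placeholder messages, for illustration only.
example = build_chat_text(
    system="Vous êtes une IA assistant les enseignants d'histoire-géographie.",
    user="Veuillez rédiger une appréciation pour le premier trimestre.",
    assistant="L'ensemble est correct.",
)
print(example)
```

Each record is a single string holding all three turns, which is what `fit` receives through the `text` column.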
Issue:
The loss stays constant at ~2.9 throughout training with no signs of learning.
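
To put the ~2.9 figure in perspective, it can be compared with the loss of a model that guesses uniformly over the vocabulary. The sketch below assumes a vocabulary of 128,256 tokens for Llama 3.2 (an assumption, not a value read from the notebook) and a loss measured in nats.

```python
import math

VOCAB_SIZE = 128_256   # assumption: Llama 3 tokenizer vocabulary size
observed_loss = 2.9    # constant loss reported in this issue

# Cross-entropy of a uniform guess over the whole vocabulary.
uniform_loss = math.log(VOCAB_SIZE)

# Perplexity implied by the observed loss.
perplexity = math.exp(observed_loss)

print(f"uniform baseline: {uniform_loss:.2f} nats")
print(f"observed loss:    {observed_loss:.2f} nats (perplexity {perplexity:.1f})")
```

Under these assumptions the observed loss is far below the uniform baseline, which would suggest that the pretrained weights are loaded correctly and that the problem lies in the updates not taking effect rather than in a broken model.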
Questions:
- Why isn't the model learning despite using similar parameters to my successful HuggingFace implementation?
- Are there specific considerations when fine-tuning LLMs with Keras on TPU?
- Is my loss function appropriate for this task?
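
On the last question, the quantity that sparse categorical cross-entropy with `from_logits=True` computes for next-token prediction can be written out by hand. The pure-Python sketch below is illustrative only: `masked_sparse_ce` is a hypothetical name, the toy logits are made up, and it averages over unmasked tokens, which may not match the exact reduction Keras applies when sample weights are present.

```python
import math


def masked_sparse_ce(logits, labels, mask):
    """Mean cross-entropy from raw logits over positions where mask is 1.

    logits: list of per-token score lists, labels: list of target ids,
    mask: list of 0/1 flags (0 marks padding to be ignored).
    """
    total, count = 0.0, 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue  # padded positions contribute nothing
        peak = max(row)
        # Numerically stable log-sum-exp of the logits.
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_z - row[label]  # equals -log softmax(row)[label]
        count += 1
    return total / count if count else 0.0


# Toy example: 3 token positions, vocabulary of 3, last position is padding.
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [5.0, 5.0, 5.0]]
labels = [0, 1, 2]
mask = [1, 1, 0]
loss = masked_sparse_ce(logits, labels, mask)
print(loss)
```

If the labels are the input tokens shifted by one position and padding is masked out, this is the standard objective for causal language-model fine-tuning.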
Thank you for your help!