src/instructlab/training/main_ds.py: 25 additions & 1 deletion
@@ -359,6 +359,8 @@ def train(
     tokenizer: PreTrainedTokenizer,
     train_loader: DataLoader,
     grad_accum,
+    metric_logger,
+    packing_max_batch_len: int,
 ):
     model.train()
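For reference, a call site under the new signature might look like the sketch below. Only `tokenizer`, `train_loader`, `grad_accum`, `metric_logger`, and `packing_max_batch_len` are confirmed by this hunk; the leading arguments are assumptions added purely for illustration.

# Hypothetical call sketch -- leading arguments are assumptions, not part of this diff.
train(
    args,            # assumed: run configuration
    model,           # assumed: prepared/wrapped model
    optimizer,       # assumed
    lr_scheduler,    # assumed
    tokenizer,
    train_loader,
    grad_accum,
    metric_logger,            # new in this PR: logger for training metrics
    packing_max_batch_len,    # new in this PR: max loss tokens per mini-batch (N)
)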
@@ -442,6 +444,22 @@ def train(
         # loss = (
         #     loss / num_loss_counted_tokens * world_size
         # ) # dividing by the total number of non-padding tokens and multiplying by the number of GPUs so when accelerate averages by world_size, it will be the correct loss.
+
+        # XXX(osilkin): Since we are accumulating gradients and sharding across cards, we need to
+        # ensure that the loss is correctly scaled across the entire mini-batch.
+        #
+        # To achieve this, we do the following:
+        # 1. Sum up the loss across nodes in a step, then scale it by the inverse of the
+        #    maximum possible number of loss tokens we might see (N).
+        # 2. Since the scaling factor is a constant, the final gradients look like
+        #    1/N * g1 + 1/N * g2 + ... + 1/N * gn, which means that for a given parameter
+        #    we can factor out the scalar: 1/N * (g1 + g2 + ... + gn).
+        # 3. At the end of the batch, we've counted the true number of loss tokens seen, T,
+        #    so we compute the adjusted scalar C = N/T as a single number.
+        # 4. Multiply the final gradient by this scalar so everything cancels out correctly:
+        #    (N/T) * 1/N * (g1 + g2 + ... + gn) = 1/T * (g1 + g2 + ... + gn)
+        loss = (
+            loss / packing_max_batch_len
+        )  # dividing by the maximum possible number of loss tokens (N); the N/T correction described above is applied to the gradients at the end of the batch.
         base_logger.info(
             f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
         )
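The correction described in the comment above can be illustrated with a minimal, self-contained sketch. This is not the PR's implementation: the helper name `rescale_accumulated_grads` and the `total_loss_tokens` bookkeeping are assumptions; only `packing_max_batch_len` (N) and the N/T idea come from the diff.

import torch

def rescale_accumulated_grads(
    model: torch.nn.Module,
    total_loss_tokens: int,      # T: true number of loss tokens seen in this batch (assumed bookkeeping)
    packing_max_batch_len: int,  # N: maximum possible loss tokens per mini-batch
) -> None:
    """Apply the C = N/T correction to accumulated gradients.

    Each micro-batch loss was divided by N, so the accumulated gradient is
    1/N * (g1 + ... + gn).  Multiplying by C = N/T leaves
    1/T * (g1 + ... + gn), i.e. the summed gradients normalized by the true
    loss-token count T.
    """
    correction = packing_max_batch_len / total_loss_tokens  # C = N / T
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(correction)

In this sketch the helper would be called once per global step, after the last micro-batch's backward() and before optimizer.step(), so the constant 1/N factor cancels exactly as in step 4 above.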
@@ -454,10 +472,14 @@ def train(
         # to numeric instability
         #
         # Also I'm not sure if this will work with FSDP Full Shard
+        # XXX(osilkin): not sure how this will work when doing full shard