
Commit 5ceb424

introduce intermediate scaling for numerical stability
Signed-off-by: Oleg S <[email protected]>
1 parent da836b0 commit 5ceb424

File tree: 1 file changed, +25 / -1 lines


src/instructlab/training/main_ds.py

Lines changed: 25 additions & 1 deletion
@@ -359,6 +359,8 @@ def train(
     tokenizer: PreTrainedTokenizer,
     train_loader: DataLoader,
     grad_accum,
+    metric_logger,
+    packing_max_batch_len: int,
 ):
     model.train()

@@ -442,6 +444,22 @@ def train(
         # loss = (
         #     loss / num_loss_counted_tokens * world_size
         # )  # dividing by the total number of non-padding tokens and multiplying by the number of GPUs so when accelerate averages by world_size, it will be the correct loss.
+
+        # XXX(osilkin): Since we are accumulating gradients and sharding across cards, we need to
+        # ensure that the loss is correctly scaled across the entire mini-batch.
+        #
+        # To achieve this, we do the following:
+        #   1. Sum up the loss across nodes in a step, then scale it by the inverse of
+        #      the maximum possible number of loss tokens we might see (N).
+        #   2. Since the scaling factor is a constant, the final gradients look like 1/N * g1 + 1/N * g2 + ... + 1/N * gn;
+        #      this means that for a given parameter, we can factor out the scalar: 1/N * (g1 + g2 + ... + gn).
+        #   3. At the end of the batch, we have counted the true number of loss tokens seen, T, so we create the
+        #      adjusted scalar C = N/T as a single number.
+        #   4. Multiply the final gradient by this scalar so everything cancels out correctly:
+        #      (N/T) * 1/N * (g1 + g2 + ... + gn) = 1/T * (g1 + g2 + ... + gn)
+        loss = (
+            loss / packing_max_batch_len
+        )  # dividing by the max possible number of loss tokens (N); the accumulated gradient is rescaled by N/T below, so the effective divisor is the true token count T.
         base_logger.info(
             f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
         )
@@ -454,10 +472,14 @@ def train(
         # to numeric instability
         #
         # Also I'm not sure if this will work with FSDP Full Shard
+        # XXX(osilkin): not sure how this will work when doing full shard
+        corrected_scalar = float(packing_max_batch_len) / float(
+            total_minibatch_loss_tokens_seen
+        )
         for p in model.parameters():
             grad = p.grad
             assert grad is not None
-            grad.mul_(1.0 / total_minibatch_loss_tokens_seen)
+            grad.mul_(corrected_scalar)

         global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0)
         optimizer.step()
@@ -700,6 +722,8 @@ def main(args):
         tokenizer,
         train_loader,
         grad_accum,
+        metric_logger,
+        packing_max_batch_len,
     )

     torch.distributed.barrier()
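
For readers following the scaling argument in the comment block added above, here is a minimal, self-contained sketch of the scheme in plain PyTorch. It is not the code from main_ds.py: the model call, the microbatch loop, the -100 label convention for counting loss tokens, and the local clip_grad_norm_ call are illustrative assumptions, while packing_max_batch_len (N) and total_minibatch_loss_tokens_seen (T) mirror the names used in the diff.

    import torch

    def accumulation_step(model, optimizer, microbatches, packing_max_batch_len):
        """One optimizer step over several microbatches using intermediate 1/N scaling."""
        optimizer.zero_grad()
        total_minibatch_loss_tokens_seen = 0

        for batch in microbatches:
            # Assumption: the model returns a summed (not averaged) token loss.
            out = model(**batch)
            # Assumption: labels use -100 for positions that do not contribute to the loss.
            total_minibatch_loss_tokens_seen += int(batch["labels"].ne(-100).sum())

            # Step 1: scale by the constant 1/N so each backward pass stays bounded.
            scaled_loss = out.loss / packing_max_batch_len
            scaled_loss.backward()  # grads accumulate as 1/N * (g1 + g2 + ...)

        # Steps 3-4: rescale by C = N/T so the effective divisor becomes the true count T.
        corrected_scalar = float(packing_max_batch_len) / float(total_minibatch_loss_tokens_seen)
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(corrected_scalar)  # grads now hold 1/T * (g1 + g2 + ...)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

With, say, N = 60000 and T = 48321, each microbatch loss is divided by the constant 60000 and the accumulated gradient is multiplied by 60000 / 48321 ≈ 1.24; algebraically that equals dividing by 48321 directly, but the intermediate 1/N scaling keeps every backward pass bounded by a fixed constant, which is the numerical-stability point of the commit. In the actual training loop, clipping goes through accelerator.clip_grad_norm_, and per step 1 of the comment the token count is presumably accumulated across ranks rather than locally as in this sketch.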

0 commit comments