-
Notifications
You must be signed in to change notification settings - Fork 515
Now, dpo.py matches dpo_tune_cache.py almost perfectly on the single GPU experiments
#1451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6ae4146
53ef2ef
74b1d4e
cdec970
893772e
3c477ba
075537d
a41e9a4
f04f080
d6bc330
2646b44
5e04417
0079553
6de7ad8
1974570
b4865ec
665707f
889bfe9
008fc86
c3feb3b
017c924
ced54f6
23e139e
016861c
844aa0b
e4575e5
fa3bd40
223622a
4d496b1
4ea43f4
9eec226
77ef874
9361ae7
1a289c2
cbe9960
1803432
a14af8f
1592470
26e8dfc
22fd8e2
8fd3771
31024f1
6143215
5a95094
72c7715
e01505b
faa15a4
f4631f9
f9ab894
7043d69
f8d6a37
0d0e930
4d55299
186b1f7
3fff585
a8f652a
19bf948
fc694ee
c5cb1bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -85,7 +85,7 @@ def _load(): | |
|
|
||
|
|
||
| def _setup_model(args: dpo_utils.ExperimentConfig, device: torch.device): | ||
| """Load and configure OLMo-core model.""" | ||
| """Build OLMo-core model architecture (weights loaded after parallelization).""" | ||
| hf_config = transformers.AutoConfig.from_pretrained(args.model_name_or_path) | ||
| vocab_size = hf_config.vocab_size | ||
| logger.info(f"Building OLMo-core model with vocab_size={vocab_size}") | ||
|
|
@@ -103,10 +103,6 @@ def _setup_model(args: dpo_utils.ExperimentConfig, device: torch.device): | |
| ) | ||
| model = model_config.build(init_device="cpu") | ||
|
|
||
| logger.info(f"Loading HuggingFace weights from {args.model_name_or_path}") | ||
| load_hf_model(args.model_name_or_path, model.state_dict(), work_dir=args.output_dir) | ||
| model = model.to(device=device, dtype=torch.bfloat16) | ||
|
|
||
| return model, model_config | ||
|
|
||
|
|
||
|
|
@@ -271,7 +267,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC | |
|
|
||
| dataset = _load_dataset_distributed(args, tc, transform_fn_args, is_main_process) | ||
| dataset = dataset.shuffle(seed=args.seed) | ||
| dataset.set_format(type="pt") # Must be after shuffle (shuffle resets format) | ||
| dataset.set_format(type="pt") | ||
|
|
||
| world_size = distributed_utils.get_world_size() if distributed_utils.is_distributed() else 1 | ||
| dp_world_size = world_size // args.tensor_parallel_degree | ||
|
|
@@ -308,6 +304,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC | |
| work_dir=args.output_dir, | ||
| collator=collator, | ||
| device=device, | ||
| drop_last=True, | ||
| ) | ||
| # 4x batch size: forward-only (no backward), so no activation storage needed. | ||
| # With packing, the collator's token budget controls the actual forward-pass size | ||
|
|
@@ -325,7 +322,7 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC | |
| work_dir=args.output_dir, | ||
| collator=collator, | ||
| device=device, | ||
| drop_last=False, | ||
| drop_last=True, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Cache data loader drops examples, causing RuntimeError (High Severity). The |
||
| ) | ||
|
|
||
| forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo | ||
|
|
@@ -350,8 +347,9 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC | |
|
|
||
| data_loader.reshuffle(epoch=0) | ||
| num_training_steps = len(data_loader) * args.num_epochs | ||
| effective_steps = args.max_train_steps if args.max_train_steps is not None else num_training_steps | ||
| optim_config = AdamWConfig(lr=args.learning_rate, weight_decay=args.weight_decay, fused=args.fused_optimizer) | ||
| scheduler = _setup_scheduler(args, num_training_steps) | ||
| scheduler = _setup_scheduler(args, effective_steps) | ||
| max_grad_norm = args.max_grad_norm if args.max_grad_norm > 0 else None | ||
| dp_config = transformer_config.TransformerDataParallelConfig( | ||
| name=DataParallelType.hsdp, | ||
|
|
@@ -384,9 +382,13 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC | |
| device=device, | ||
| ) | ||
|
|
||
| # Build reference cache after train_module init because TransformerTrainModule applies | ||
| # FSDP parallelism to the model, and we need the parallelized model to calculate the | ||
| # logprobs in case the model is too big to fit in memory. | ||
| # TransformerTrainModule.__init__ calls parallelize_model which calls init_weights, | ||
finbarrtimbers marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # reinitializing all model weights from scratch. We must reload the HF checkpoint. | ||
| logger.info("Reloading HuggingFace weights after parallelization...") | ||
| sd = train_module.model.state_dict() | ||
| load_hf_model(args.model_name_or_path, sd, work_dir=args.output_dir) | ||
| train_module.model.load_state_dict(sd) | ||
|
|
||
| logger.info("Caching reference logprobs...") | ||
| train_module.reference_cache = dpo_utils.build_reference_logprobs_cache(model=train_module.model, **cache_kwargs) | ||
|
|
||
|
|
@@ -399,14 +401,21 @@ def main(args: dpo_utils.ExperimentConfig, tc: dataset_transformation.TokenizerC | |
|
|
||
| trainer_callbacks = _setup_callbacks(args, dp_world_size) | ||
|
|
||
| if args.max_train_steps is not None: | ||
| max_duration = train.Duration.steps(args.max_train_steps) | ||
| else: | ||
| max_duration = train.Duration.steps(num_training_steps) | ||
|
|
||
| trainer = train.TrainerConfig( | ||
| save_folder=args.output_dir, | ||
| max_duration=train.Duration.epochs(args.num_epochs), | ||
| max_duration=max_duration, | ||
| metrics_collect_interval=args.logging_steps, | ||
| callbacks=trainer_callbacks, | ||
| save_overwrite=True, | ||
| ).build(train_module, data_loader) | ||
|
|
||
| trainer.epoch = 0 | ||
|
|
||
| logger.info("Starting training...") | ||
| trainer.fit() | ||
| logger.info("Training complete.") | ||
|
|
||


Uh oh!
There was an error while loading. Please reload this page.