
Commit 5402b29

Thomas Capelle authored
Improve Wandb experience (#660)
Co-authored-by: Thomas Capelle <[email protected]>
Co-authored-by: Kartikay Khandelwal <[email protected]>
Co-authored-by: yechenzhi <[email protected]>
Co-authored-by: Rafi Ayub <[email protected]>
Co-authored-by: Rohan Varma <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Botao Chen <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Mengtao Yuan <[email protected]>
Co-authored-by: solitude-alive <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: RdoubleA <[email protected]>
1 parent 053d0ae commit 5402b29

File tree

11 files changed: +239 −77 lines

docs/source/_static/img/torchtune_workspace.png

645 KB (binary file)
docs/source/examples/wandb_logging.rst

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
.. _wandb_logging:

===========================
Logging to Weights & Biases
===========================

.. customcarditem::
   :header: Logging to Weights & Biases
   :card_description: Log metrics and model checkpoints to W&B
   :image: _static/img/torchtune_workspace.png
   :link: examples/wandb_logging.html
   :tags: logging,wandb


Torchtune supports logging your training runs to `Weights & Biases <https://wandb.ai>`_.

.. note::

  You will need to install the ``wandb`` package to use this feature.
  You can install it via pip:

  .. code-block:: bash

    pip install wandb

  Then you need to log in with your API key using the W&B CLI:

  .. code-block:: bash

    wandb login

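If you prefer to avoid the interactive prompt, ``wandb`` can also authenticate from Python. A minimal sketch (this is standard ``wandb`` behavior, not specific to torchtune; ``WANDB_API_KEY`` is the conventional environment variable):

.. code-block:: python

  import os

  import wandb

  # Reads the key from the environment; wandb.login() is a no-op
  # if you are already logged in.
  wandb.login(key=os.environ.get("WANDB_API_KEY"))
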
Metric Logger
-------------

The only change you need to make is to add the metric logger to your config. Weights & Biases will log the metrics and model checkpoints for you.

.. code-block:: yaml

  # enable logging to the built-in WandBLogger
  metric_logger:
    _component_: torchtune.utils.metric_logging.WandBLogger
    # the W&B project to log to
    project: torchtune


We automatically grab the config from the recipe you are running and log it to W&B. You can find it in the W&B Overview tab, and the actual file in the ``Files`` tab.

.. note::

  Click on this sample `project to see the W&B workspace <https://wandb.ai/capecape/torchtune>`_.
  The config used to train the models can be found `here <https://wandb.ai/capecape/torchtune/runs/6053ofw0/files/torchtune_config_j67sb73v.yaml>`_.

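Under the hood, each recipe calls ``log_config`` on the metric logger right after instantiating it. A rough sketch of what a W&B ``log_config`` can do (simplified; see ``torchtune.utils.metric_logging`` for the actual implementation):

.. code-block:: python

  from omegaconf import DictConfig, OmegaConf

  import wandb

  def log_config(cfg: DictConfig) -> None:
      # Make the resolved config browsable in the W&B Overview tab...
      resolved = OmegaConf.to_container(cfg, resolve=True)
      wandb.config.update(resolved)
      # ...and upload the raw YAML to the run's Files tab.
      OmegaConf.save(config=cfg, f="torchtune_config.yaml")
      wandb.save("torchtune_config.yaml")
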
Logging Model Checkpoints to W&B
--------------------------------

You can also log the model checkpoints to W&B by modifying the ``save_checkpoint`` method of the desired recipe.

A suggested approach would be something like this:

.. code-block:: python

  def save_checkpoint(self, epoch: int) -> None:
      ...
      # Let's save the checkpoint to W&B.
      # Depending on the Checkpointer class, the file will be named differently.
      # Here is an example for the full_finetune case.
      checkpoint_file = Path.joinpath(
          self._checkpointer._output_dir, f"torchtune_model_{epoch}"
      ).with_suffix(".pt")
      wandb_at = wandb.Artifact(
          name=f"torchtune_model_{epoch}",
          type="model",
          # description of the model checkpoint
          description="Model checkpoint",
          # you can add whatever metadata you want as a dict
          metadata={
              utils.SEED_KEY: self.seed,
              utils.EPOCHS_KEY: self.epochs_run,
              utils.TOTAL_EPOCHS_KEY: self.total_epochs,
              utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
          },
      )
      wandb_at.add_file(checkpoint_file)
      wandb.log_artifact(wandb_at)
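
Once logged, a checkpoint artifact can be pulled back down with the public W&B API. A short usage sketch (``<entity>`` and ``<project>`` are placeholders for your own values):

.. code-block:: python

  import wandb

  api = wandb.Api()
  artifact = api.artifact("<entity>/<project>/torchtune_model_0:latest")
  checkpoint_dir = artifact.download()  # returns the local download path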

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ torchtune tutorials.

    examples/checkpointer
    examples/configs
    examples/recipe_deepdive
+   examples/wandb_logging

 .. toctree::
    :glob:

recipes/full_finetune_distributed.py

Lines changed: 13 additions & 9 deletions
@@ -171,7 +171,11 @@ def setup(self, cfg: DictConfig) -> None:
         Sets up the recipe state correctly. This includes setting recipe attributes based
         on the ``resume_from_checkpoint`` flag.
         """
-        self._metric_logger = config.instantiate(cfg.metric_logger)
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)

         ckpt_dict = self.load_checkpoint(cfg.checkpointer)

@@ -291,11 +295,8 @@ def _setup_model(
             model, auto_wrap_policy={modules.TransformerDecoderLayer}
         )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init", device=self._device
-                )
-            )
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")

         # synchronize before training begins
         torch.distributed.barrier()

@@ -475,15 +476,18 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats", device=self._device)
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )

             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)

     def cleanup(self) -> None:
-        self._metric_logger.close()
+        if self._is_rank_zero:
+            self._metric_logger.close()
         torch.distributed.destroy_process_group()
recipes/full_finetune_single_device.py

Lines changed: 13 additions & 8 deletions
@@ -176,6 +176,9 @@ def setup(self, cfg: DictConfig) -> None:
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)

+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         ckpt_dict = self.load_checkpoint(cfg.checkpointer)

         # ``_setup_model`` handles initialization and loading the state dict. This method

@@ -257,11 +260,9 @@ def _setup_model(
         if compile_model:
             log.info("Compiling model with torch.compile...")
             model = utils.wrap_compile(model)
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model

     def _setup_optimizer(

@@ -440,9 +441,13 @@ def train(self) -> None:
                     self.total_training_steps += 1

                 # Log peak memory for iteration
-                if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
+                if (
+                    self.total_training_steps % self._log_peak_memory_every_n_steps == 0
+                    and self._device == torch.device("cuda")
+                ):
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
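
Note that ``utils.memory_stats_log`` now returns a dictionary of metrics rather than a preformatted string, which is what lets the recipes feed it straight to ``log_dict`` (and hence to W&B) as scalar time series. Its exact keys aren't visible in this diff; a rough stand-in built on public ``torch.cuda`` counters might look like:

.. code-block:: python

  import torch

  def memory_stats_sketch(device: torch.device) -> dict:
      # Hypothetical stand-in for utils.memory_stats_log; key names are illustrative.
      gib = 1024**3
      stats = torch.cuda.memory_stats(device)
      return {
          "peak_memory_active": stats.get("active_bytes.all.peak", 0) / gib,
          "peak_memory_alloc": torch.cuda.max_memory_allocated(device) / gib,
          "peak_memory_reserved": torch.cuda.max_memory_reserved(device) / gib,
      }

The added ``self._device == torch.device("cuda")`` guard matters because these counters only exist for CUDA devices.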

recipes/gemma_full_finetune_distributed.py

Lines changed: 13 additions & 10 deletions
@@ -146,7 +146,11 @@ def setup(self, cfg: DictConfig) -> None:
         Sets up the recipe state correctly. This includes setting recipe attributes based
         on the ``resume_from_checkpoint`` flag.
         """
-        self._metric_logger = config.instantiate(cfg.metric_logger)
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)

         ckpt_dict = self.load_checkpoint(cfg.checkpointer)

@@ -263,12 +267,8 @@ def _setup_model(
             model, auto_wrap_policy={modules.TransformerDecoderLayer}
         )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init", device=self._device
-                )
-            )
-
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         # synchronize before training begins
         torch.distributed.barrier()

@@ -458,15 +458,18 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats", device=self._device)
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )

             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)

     def cleanup(self) -> None:
-        self._metric_logger.close()
+        if self._is_rank_zero:
+            self._metric_logger.close()
         torch.distributed.destroy_process_group()
recipes/lora_dpo_single_device.py

Lines changed: 14 additions & 8 deletions
@@ -148,6 +148,9 @@ def setup(self, cfg: DictConfig) -> None:
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)

+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

         self._model = self._setup_model(

@@ -252,11 +255,9 @@ def _setup_model(
         )

         log.info(f"Model is initialized with precision {self._dtype}.")
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model

     def _setup_optimizer(

@@ -490,9 +491,14 @@ def train(self) -> None:
                     # Update the number of steps when the weights are updated
                     self.total_training_steps += 1
                 # Log peak memory for iteration
-                if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
+                if (
+                    self.total_training_steps % self._log_peak_memory_every_n_steps == 0
+                    and self._device == torch.device("cuda")
+                ):
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
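
``log_dict`` is part of torchtune's metric-logger interface, and for W&B it boils down to a single ``wandb.log`` call per step. A simplified sketch (not the full ``WandBLogger`` from ``torchtune.utils.metric_logging``):

.. code-block:: python

  from typing import Mapping, Optional

  import wandb

  class MiniWandBLogger:
      # Stripped-down sketch of a W&B metric logger.
      def __init__(self, project: str) -> None:
          self._run = wandb.init(project=project)

      def log_dict(self, payload: Mapping[str, float], step: Optional[int] = None) -> None:
          # wandb.log takes a dict of scalars plus an optional global step.
          self._run.log(dict(payload), step=step)

      def close(self) -> None:
          self._run.finish()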

recipes/lora_finetune_distributed.py

Lines changed: 9 additions & 7 deletions
@@ -197,6 +197,9 @@ def setup(self, cfg: DictConfig) -> None:
         if self._is_rank_zero:
             self._metric_logger = config.instantiate(cfg.metric_logger)

+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
+
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

         self._model = self._setup_model(

@@ -353,11 +356,8 @@ def _setup_model(
             model, auto_wrap_policy={modules.TransformerDecoderLayer}
         )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init:", device=self._device
-                )
-            )
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")

         # synchronize before training begins
         torch.distributed.barrier()

@@ -571,8 +571,10 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )

             self.epochs_run += 1

recipes/lora_finetune_single_device.py

Lines changed: 12 additions & 7 deletions
@@ -174,6 +174,10 @@ def setup(self, cfg: DictConfig) -> None:
         model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)
+
+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         self._model_compile = cfg.compile
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

@@ -291,11 +295,9 @@ def _setup_model(
         if compile_model:
             log.info("Compiling model with torch.compile...")
             model = utils.wrap_compile(model)
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model

     def _setup_optimizer(

@@ -474,9 +476,12 @@ def train(self) -> None:
                 if (
                     self.total_training_steps % self._log_peak_memory_every_n_steps
                     == 0
+                    and self._device == torch.device("cuda")
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
