Commit 96a63f3

starting fp8 logging
Signed-off-by: Jonathan Mitchell <[email protected]>
1 parent 20076cc commit 96a63f3

4 files changed: +50 -30 lines


.devcontainer/recipes/requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -14,3 +14,6 @@ transformers
 typer
 wandb
 zstandard
+nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect
+
+

Lines changed: 6 additions & 11 deletions
@@ -1,28 +1,23 @@
 example_fp8_tensor_stat_collection:
   enabled: True
   layers:
-    layer_types: [layernorm_linear]
+    # Match the actual linear layers within attention that support FP8 stats
+    layer_types: [layernorm_qkv, proj]
   transformer_engine:
     LogFp8TensorStats:
       enabled: True
       tensors_struct:
         - tensor: activation
-          stats: [fp8_block_scaling_underflows%]
+          stats: [underflows%, overflows%]
           freq: 1
         - tensor: activation
-          stats: [fp8_block_scaling_overflows%]
+          stats: [scale_inv_min, scale_inv_max]
           freq: 1
         - tensor: activation
-          stats: [fp8_block_scaling_scale_inv_min]
-          freq: 1
-        - tensor: activation
-          stats: [fp8_block_scaling_scale_inv_max]
-          freq: 1
-        - tensor: activation
-          stats: [fp8_block_scaling_mse]
+          stats: [mse]
           freq: 1
         - tensor: gradient
           stats: [underflows%]
           freq: 5
       start_step: 0
-      end_step: 80
+      end_step: 80
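
For context, this is an nvidia-dlfw-inspect stats-collection config consumed by Transformer Engine's debug features: `layer_types` restricts which layers the feature applies to, each `tensors_struct` entry pairs a tensor with the stats to log at a given `freq`, and `start_step`/`end_step` bound the collection window. A minimal sketch of how a config like this gets loaded (the filename and log directory below are placeholders; the commit itself hard-codes absolute paths inside the recipe):

import nvdlfw_inspect.api as debug_api

debug_api.initialize(
    config_file="fp8_stats_block_scaling.yaml",  # a stats config like the one above (placeholder path)
    feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
    log_dir="./log",  # collected statistics are written under this directory
    default_logging_enabled=True,
)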

bionemo-recipes/recipes/esm2_native_te/train_ddp.py

Lines changed: 17 additions & 3 deletions
@@ -16,6 +16,7 @@
 import logging
 from pathlib import Path
 
+import nvdlfw_inspect.api as debug_api
 import hydra
 import torch
 import transformer_engine.pytorch

@@ -43,6 +44,13 @@ def main(args: DictConfig) -> float | None:
     Returns:
         float: The loss value for the final batch.
     """
+    # TE Debug feature logging - MUST be done BEFORE FSDP wrapping
+    debug_api.initialize(
+        config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats_block_scaling_ddp.yaml",
+        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
+        log_dir="./logddp",
+        default_logging_enabled=True,
+    )
     # Initialize the distributed configuration, including creating the distributed process group.
     dist_config = DistributedConfig()
     logger.info("Initializing distributed training: %s", dist_config)

@@ -65,6 +73,7 @@ def main(args: DictConfig) -> float | None:
     if args.use_sequence_packing:
         config.attn_input_format = "thd"
 
+
     # Optionally use transformer engine to initialize only fp8 versions of weights by setting
     # `fp8_config.fp8_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
     # versions of weights are kept.

@@ -80,10 +89,14 @@ def main(args: DictConfig) -> float | None:
     except AttributeError:
         pass
 
+
+
     # Create optimizer.
     optimizer = AdamW(model.parameters(), **args.adamw_kwargs)
     scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
 
+    debug_api.infer_and_assign_layer_names(model)
+
     model = model.to(device=device)
     model = torch.nn.parallel.DistributedDataParallel(
         model,

@@ -99,9 +112,9 @@ def main(args: DictConfig) -> float | None:
         else create_bshd_dataloader(dist_config, **args.dataset)
     )
 
-    if args.use_torch_compile:
-        # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
-        model = torch.compile(model)
+    # if args.use_torch_compile:
+    #     # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
+    #     model = torch.compile(model)
 
     # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
     ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None

@@ -134,6 +147,7 @@ def main(args: DictConfig) -> float | None:
         loss = outputs.loss
         loss.backward()
 
+        debug_api.step()
        # Compute and clip gradient norms.
         total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
 
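
One note on the hard-coded `feature_dirs` path above: it pins a specific Python version's dist-packages location. A hedged alternative (not part of this commit) is to derive Transformer Engine's debug-features directory from the installed package instead:

import os

import transformer_engine

# Resolve the debug/features directory shipped with the installed Transformer Engine,
# instead of hard-coding /usr/local/lib/python3.12/dist-packages/...
te_feature_dir = os.path.join(os.path.dirname(transformer_engine.__file__), "debug", "features")

# e.g. debug_api.initialize(..., feature_dirs=[te_feature_dir], ...)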

bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py

Lines changed: 24 additions & 16 deletions
@@ -18,6 +18,7 @@
 from pathlib import Path
 
 import hydra
+import nvdlfw_inspect.api as debug_api
 import torch
 import transformer_engine.pytorch
 from omegaconf import DictConfig, OmegaConf

@@ -36,8 +37,6 @@
 from perf_logger import PerfLogger
 from scheduler import get_linear_schedule_with_warmup
 
-import nvdlfw_inspect.api as debug_api
-
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

@@ -86,12 +85,31 @@ def main(args: DictConfig) -> float | None:
 
     logger.info("Initialized Model:\n%s", model)
 
+    # TE Debug feature logging - MUST be done BEFORE FSDP wrapping
+    debug_api.initialize(
+        config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats_block_scaling.yaml",
+        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
+        log_dir="./log",
+        default_logging_enabled=True,
+    )
+    # Debug: Print module types to verify what we're working with
+    if dist_config.local_rank == 0:
+        logger.info("=== DEBUG: Module types in model ===")
+        for name, module in model.named_modules():
+            if 'layernorm_qkv' in name or 'proj' in name or 'self_attention' in name:
+                logger.info(f"  -----> {name}: {type(module)} <----")
+        logger.info(f"=== DEBUG: FP8 config enabled={args.fp8_config.enabled}, recipe={args.fp8_config.fp8_recipe} ===")
+
     # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models.
     transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
+
     for layer in transformer_stack:
         fully_shard(layer, mesh=device_mesh["dp"])
     fully_shard(model, mesh=device_mesh["dp"])
 
+    # Assign names to layers so debug API can identify them - MUST be done BEFORE FSDP wrapping
+    debug_api.infer_and_assign_layer_names(model)
+
     # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
     optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore
     scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

@@ -107,9 +125,9 @@ def main(args: DictConfig) -> float | None:
         else create_bshd_dataloader(dist_config, **args.dataset)
     )
 
-    if args.use_torch_compile:
-        # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
-        model = torch.compile(model)
+    # if args.use_torch_compile:
+    #     # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
+    #     model = torch.compile(model)
 
     # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
     ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None

@@ -128,15 +146,6 @@ def main(args: DictConfig) -> float | None:
 
     perf_logger = PerfLogger(dist_config, args)
 
-    # TE Debug feature logging
-    debug_api.initialize(
-        config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats.yaml",
-        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
-        log_dir="./log",
-        default_logging_enabled=True
-    )
-
-
     # Training loop
     step = start_step
     while step < args.num_train_steps:

@@ -159,7 +168,7 @@ def main(args: DictConfig) -> float | None:
         scheduler.step()
 
         debug_api.step()
-
+
         optimizer.zero_grad()
 
         perf_logger.log_step(

@@ -183,7 +192,6 @@ def main(args: DictConfig) -> float | None:
                 max_checkpoints=args.checkpoint.max_checkpoints,
             )
 
-
         step += 1
         if step >= args.num_train_steps:
             break
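
Taken together, the FSDP2 changes move the debug setup ahead of sharding and name the layers afterwards. A condensed sketch of the ordering this commit establishes, wrapped in a hypothetical helper (`wire_up_fp8_stats` is not in the recipe, and the `fully_shard` import path assumes a recent PyTorch and may differ from the recipe's own import):

import nvdlfw_inspect.api as debug_api
from torch.distributed.fsdp import fully_shard  # assumption: PyTorch >= 2.6 public path


def wire_up_fp8_stats(model, transformer_stack, device_mesh, config_file, log_dir):
    """Hypothetical helper mirroring the order of operations in train_fsdp2.py."""
    # 1. Initialize the debug API before any FSDP wrapping.
    debug_api.initialize(
        config_file=config_file,
        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
        log_dir=log_dir,
        default_logging_enabled=True,
    )
    # 2. Shard each transformer layer, then the full model.
    for layer in transformer_stack:
        fully_shard(layer, mesh=device_mesh["dp"])
    fully_shard(model, mesh=device_mesh["dp"])
    # 3. Assign layer names so the YAML's layer/tensor selections can find them.
    debug_api.infer_and_assign_layer_names(model)
    return model

In the training loop itself, debug_api.step() is then called once per optimization step (the commit places it right after scheduler.step()), presumably advancing the step counter that the config's freq/start_step/end_step settings key off.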
