 from pathlib import Path

 import hydra
+import nvdlfw_inspect.api as debug_api
 import torch
 import transformer_engine.pytorch
 from omegaconf import DictConfig, OmegaConf
@@ -36,8 +37,6 @@
 from perf_logger import PerfLogger
 from scheduler import get_linear_schedule_with_warmup

-import nvdlfw_inspect.api as debug_api
-

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -86,12 +85,31 @@ def main(args: DictConfig) -> float | None:

     logger.info("Initialized Model:\n %s", model)

+    # TE Debug feature logging - MUST be done BEFORE FSDP wrapping
+    debug_api.initialize(
+        config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats_block_scaling.yaml",
+        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
+        log_dir="./log",
+        default_logging_enabled=True,
+    )
+    # Debug: Print module types to verify what we're working with
+    if dist_config.local_rank == 0:
+        logger.info("=== DEBUG: Module types in model ===")
+        for name, module in model.named_modules():
+            if 'layernorm_qkv' in name or 'proj' in name or 'self_attention' in name:
+                logger.info(f" -----> {name}: {type(module)} <----")
+        logger.info(f"=== DEBUG: FP8 config enabled={args.fp8_config.enabled}, recipe={args.fp8_config.fp8_recipe} ===")
+
     # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models.
     transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
+
     for layer in transformer_stack:
         fully_shard(layer, mesh=device_mesh["dp"])
     fully_shard(model, mesh=device_mesh["dp"])

+    # Assign names to layers so the debug API can identify them
+    debug_api.infer_and_assign_layer_names(model)
+
     # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
     optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore
     scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
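The `fp8_stats_block_scaling.yaml` file itself is not part of this diff. As a rough sketch of what such a file might contain, a TE precision-debug config generally pairs a layer selector with the logging features to enable; the section name, regex pattern, and stat names below are illustrative assumptions, not the recipe's actual config:

```yaml
# Hypothetical sketch - not the actual fp8_stats_block_scaling.yaml from the recipe.
fp8_tensor_stats:                        # arbitrary section name
  enabled: True
  layers:
    layer_name_regex_pattern: .*         # match every layer named by infer_and_assign_layer_names
  transformer_engine:
    LogTensorStats:                      # high-precision tensor statistics
      enabled: True
      tensors: [activation, gradient, weight]
      stats: [max, min, mean, std, l1_norm]
      freq: 1
    LogFp8TensorStats:                   # FP8 quantization statistics (exact stat names may differ)
      enabled: True
      tensors: [activation, gradient, weight]
      stats: [underflows%]
      freq: 1
```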
@@ -107,9 +125,9 @@ def main(args: DictConfig) -> float | None:
         else create_bshd_dataloader(dist_config, **args.dataset)
     )

-    if args.use_torch_compile:
-        # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
-        model = torch.compile(model)
+    # if args.use_torch_compile:
+    #     # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
+    #     model = torch.compile(model)

     # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
     ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
@@ -128,15 +146,6 @@ def main(args: DictConfig) -> float | None:

     perf_logger = PerfLogger(dist_config, args)

-    # TE Debug feature logging
-    debug_api.initialize(
-        config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats.yaml",
-        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
-        log_dir="./log",
-        default_logging_enabled=True
-    )
-
-
     # Training loop
     step = start_step
     while step < args.num_train_steps:
@@ -159,7 +168,7 @@ def main(args: DictConfig) -> float | None:
             scheduler.step()

             debug_api.step()
-
+
             optimizer.zero_grad()

             perf_logger.log_step(
@@ -183,7 +192,6 @@ def main(args: DictConfig) -> float | None:
                 max_checkpoints=args.checkpoint.max_checkpoints,
             )

-
             step += 1
             if step >= args.num_train_steps:
                 break
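Taken together, the debug hooks this commit moves and adds follow a simple lifecycle. The sketch below summarizes the intended call order; `model`, `optimizer`, `train_step`, `batches`, and `num_train_steps` are placeholders, and the final `end_debug()` call is assumed from the nvdlfw_inspect API rather than taken from this diff:

```python
import nvdlfw_inspect.api as debug_api

# 1. Initialize once per process, before the model is sharded, pointing at the
#    YAML config and the Transformer Engine debug feature directory (paths illustrative).
debug_api.initialize(
    config_file="fp8_stats_block_scaling.yaml",
    feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
    log_dir="./log",
    default_logging_enabled=True,
)

# 2. Give each TE module a stable name so the config's layer selectors can match it.
debug_api.infer_and_assign_layer_names(model)

# 3. Advance the debug step counter once per optimizer step so the logged
#    statistics line up with training iterations.
for step in range(num_train_steps):
    loss = train_step(model, next(batches))  # placeholder training step
    optimizer.step()
    debug_api.step()
    optimizer.zero_grad()

# 4. Tear down the debug tooling when training finishes (assumed API).
debug_api.end_debug()
```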