Skip to content

Commit a384b53

Browse files
committed
fix(megatron): patch validate_args and add ROCm argument validation
Hook Megatron validate_args alongside parse_args so Primus-injected arguments are validated consistently, and run additional ROCm-specific argument checks during initialization.
1 parent 415f11b commit a384b53

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

primus/backends/megatron/megatron_base_trainer.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,35 @@ def init(self):
2323
# log_dict_aligned("Backend arguments", self.backend_args)
2424

2525
def _patch_parse_args(self):
    """Patch Megatron's ``parse_args`` and ``validate_args`` entry points.

    ``parse_args`` is replaced, in both ``megatron.training.arguments`` and
    ``megatron.training.initialize``, by a stub that logs the call and
    returns ``self.backend_args``.

    ``validate_args`` is wrapped in the same two modules. The wrapper calls
    the validator that was installed before patching and then, when an
    arguments object was supplied (first positional argument or the ``args``
    keyword), runs ``validate_args_on_rocm`` on it. The wrapped validator's
    return value is passed through unchanged.

    Calling this method again re-installs the ``parse_args`` stub for the
    current instance but reuses an existing ``validate_args`` wrapper, so the
    ROCm checks are not stacked and run once per ``validate_args`` call.
    """
    import functools

    import megatron.training.arguments as megatron_args  # type: ignore
    import megatron.training.initialize as megatron_init  # type: ignore
    from primus.modules.trainer.megatron.utils import validate_args_on_rocm

    log_rank_0("Patching Megatron-LM parse_args()")

    def patched_parse_args(*args, **kwargs):
        # Whatever Megatron passes in is ignored. Return explicitly rather
        # than via ``log(...) or value`` so the result does not depend on
        # what the logging helper happens to return.
        log_rank_0("parse_args() called; returning Primus arguments")
        return self.backend_args

    current_validate_args = megatron_args.validate_args
    if getattr(current_validate_args, "_primus_rocm_patched", False):
        # Already wrapped by an earlier call; wrapping again would run the
        # ROCm validation once per layer of wrapping.
        patched_validate_args = current_validate_args
    else:
        original_validate_args = current_validate_args

        @functools.wraps(original_validate_args)
        def patched_validate_args(*args, **kwargs):
            validated_args = original_validate_args(*args, **kwargs)
            parsed_args = args[0] if args else kwargs.get("args", None)
            if parsed_args is not None:
                log_rank_0("validate_args() called; validating on ROCM")
                validate_args_on_rocm(parsed_args)
            return validated_args

        # Marker checked above to keep repeated patching idempotent.
        patched_validate_args._primus_rocm_patched = True

    # Patch both modules: each is assigned its own module-level binding here,
    # mirroring how parse_args and validate_args are reached from either one.
    megatron_args.parse_args = patched_parse_args
    megatron_init.parse_args = patched_parse_args

    megatron_args.validate_args = patched_validate_args
    megatron_init.validate_args = patched_validate_args

    log_rank_0(
        f"Patched parse_args()/validate_args(); Primus provided {len(vars(self.backend_args))} arguments"
    )

0 commit comments

Comments
 (0)