
Commit ef12f16

wdykas authored
Prep for refit (NVIDIA#2590)
Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: root <[email protected]>
1 parent bd32927 commit ef12f16

File tree

14 files changed, +173 −48 lines


gpt_builders.py

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@
 # NOTE: Loading `megatron.legacy.model` earlier fails due to circular import
 
 
-def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None):
+def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None):
     print_rank_0('building GPT model ...')
     if config is None:
         if args.yaml_cfg is not None:
@@ -93,6 +93,7 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None):
         rope_scaling=args.use_rope_scaling,
         mtp_block_spec=mtp_block_spec,
         vp_stage=vp_stage,
+        pg_collection=pg_collection,
     )
 
     return model

mamba_builders.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,7 @@
 from megatron.training.arguments import core_transformer_config_from_args
 from megatron.core.models.mamba.mamba_layer_specs import mamba_inference_stack_spec
 
-def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None):
+def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None):
     print_rank_0('building MAMBA model ...')
     if config is None:
         config = core_transformer_config_from_args(args, TransformerConfig)
@@ -37,6 +37,7 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None):
         position_embedding_type=args.position_embedding_type,
         rotary_percent=args.rotary_percent,
         rotary_base=args.rotary_base,
+        pg_collection=pg_collection,
     )
 
     for l in range(model.decoder.num_layers_per_pipeline_rank):

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 24 additions & 11 deletions
@@ -19,7 +19,6 @@
 from torch import Tensor
 from torch.cuda.nvtx import range_pop, range_push
 
-from megatron.core import parallel_state
 from megatron.core.inference.contexts.dynamic_context import (
     DynamicInferenceContext,
     MaxSequenceLengthOverflowError,
@@ -40,8 +39,16 @@
     TextGenerationController,
 )
 from megatron.core.inference.utils import Counter, await_process_event
+from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
-from megatron.core.utils import get_asyncio_loop, internal_api, trace_async_exceptions
+from megatron.core.utils import (
+    get_asyncio_loop,
+    get_pg_rank,
+    get_pg_size,
+    get_pg_src_rank,
+    internal_api,
+    trace_async_exceptions,
+)
 
 try:
     from tqdm import tqdm
@@ -136,6 +143,7 @@ def __init__(
         track_paused_request_events: bool = False,
         enable_chunked_prefill: bool = True,
         inference_logging_step_interval: int = 0,
+        pg_collection: Optional[ProcessGroupCollection] = None,
     ):
 
         assert isinstance(
@@ -159,6 +167,11 @@ def __init__(
             controller.inference_wrapped_model.model.config.enable_cuda_graph
         )
 
+        if pg_collection is not None:
+            self.pg_collection = pg_collection
+        else:
+            self.pg_collection = ProcessGroupCollection.use_mpu_process_groups()
+
         # Initialization options.
         self.controller = controller
         self.context = context
@@ -378,15 +391,15 @@ async def start_listening_to_data_parallel_coordinator(
         self.zmq_sockets = []  # keep track of all sockets created by this engine
 
         # Get world info.
-        dp_group = parallel_state.get_data_parallel_group()
-        dp_src = parallel_state.get_data_parallel_src_rank()
-        dp_size = parallel_state.get_data_parallel_world_size()
-        dp_rank = parallel_state.get_data_parallel_rank()
+        dp_group = self.pg_collection.dp
+        dp_src = get_pg_src_rank(dp_group)
+        dp_size = get_pg_size(self.pg_collection.dp)
+        dp_rank = get_pg_rank(self.pg_collection.dp)
 
-        mp_group = parallel_state.get_model_parallel_group()
-        mp_src = parallel_state.get_model_parallel_src_rank()
-        tp_rank = parallel_state.get_tensor_model_parallel_rank()
-        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+        mp_group = self.pg_collection.mp
+        mp_src = get_pg_src_rank(mp_group)
+        tp_rank = get_pg_rank(self.pg_collection.tp)
+        pp_rank = get_pg_rank(self.pg_collection.pp)
 
         self.is_mp_coordinator = tp_rank == 0 and pp_rank == 0
         self.is_dp_coordinator = (dp_rank == 0) and self.is_mp_coordinator
@@ -400,7 +413,7 @@ async def start_listening_to_data_parallel_coordinator(
             args=(
                 coordinator_ready_event,
                 inference_coordinator_port,
-                parallel_state.get_data_parallel_world_size(),
+                get_pg_size(self.pg_collection.dp),
             ),
         )
         self.inference_coordinator_process.start()
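
A minimal sketch of how the coordinator roles now fall out of the process-group collection (this helper is hypothetical and simply restates the rank logic in the hunk above; pg is assumed to be a ProcessGroupCollection exposing tp, pp, and dp groups as in the diff):

from megatron.core.utils import get_pg_rank

def coordinator_roles(pg):
    # Hypothetical helper mirroring the engine's logic above.
    tp_rank = get_pg_rank(pg.tp)
    pp_rank = get_pg_rank(pg.pp)
    dp_rank = get_pg_rank(pg.dp)
    is_mp_coordinator = (tp_rank == 0) and (pp_rank == 0)     # first rank of its model-parallel group
    is_dp_coordinator = (dp_rank == 0) and is_mp_coordinator  # one coordinator across data parallel
    return is_mp_coordinator, is_dp_coordinator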

megatron/core/utils.py

Lines changed: 17 additions & 0 deletions
@@ -555,6 +555,23 @@ def get_pg_rank(group=None):
     return group.rank()
 
 
+def get_pg_src_rank(group=None):
+    """Calculate the global rank corresponding to the first local rank
+    in the given process group.
+
+    Args:
+        group: Process group to query. If None or distributed is not initialized,
+            returns 0.
+
+    Returns:
+        int: The first (source) global rank in the group.
+    """
+    if not torch.distributed.is_initialized() or group is None:
+        return 0
+    ranks = torch.distributed.get_process_group_ranks(group)
+    return ranks[0]
+
+
 def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False):
     """Get an attribute from a wrapped model.
     If return_model_obj is true, return the object that has the 'attr' attribute;
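
For orientation, a usage sketch of the new helper (the broadcast pattern below is an assumed example, not code from this commit): get_pg_src_rank returns the global rank of a group's first member, which is the rank conventionally used as a broadcast source.

import torch
from megatron.core.utils import get_pg_src_rank

def broadcast_from_group_source(tensor, group):
    # Global rank of the group's first member, used as the broadcast origin.
    src = get_pg_src_rank(group)
    torch.distributed.broadcast(tensor, src=src, group=group)
    return tensor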

megatron/rl/inference/megatron.py

Lines changed: 19 additions & 3 deletions
@@ -8,6 +8,7 @@
 from pydantic import PrivateAttr
 
 from megatron.core import parallel_state
+from megatron.core.utils import get_attr_wrapped_model
 from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
 from megatron.core.inference.engines.abstract_engine import AbstractEngine
 from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine
@@ -26,7 +27,11 @@
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.utils import get_mamba_inference_state_config_from_model, log_single_rank
+from megatron.core.pipeline_parallel.utils import (
+    is_pp_first_stage,
+    is_pp_last_stage,
+)
+from megatron.core.utils import get_mamba_inference_state_config_from_model, log_single_rank, get_pg_size
 from megatron.training import get_wandb_writer
 from megatron.training.global_vars import get_args, get_tokenizer
 
@@ -109,6 +114,16 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
 
     mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
 
+    # DynamicInferenceContext must use the inference model's TP size, not the
+    # training TP size from global args. The inference model may have a custom
+    # ProcessGroupCollection with a different TP size.
+    pg_collection = get_attr_wrapped_model(model, "pg_collection")
+    tp_group = getattr(pg_collection, 'tp', None) if pg_collection is not None else None
+    if tp_group is not None:
+        inference_tp_size = get_pg_size(tp_group)
+    else:
+        inference_tp_size = args.tensor_model_parallel_size
+
     # Inference context.
     inference_context = DynamicInferenceContext(
         params_dtype=args.params_dtype,
@@ -126,7 +141,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
         block_size_tokens=args.inference_dynamic_batching_block_size,
         buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
         max_tokens=args.inference_dynamic_batching_max_tokens,
-        tensor_model_parallel_size=args.tensor_model_parallel_size,
+        tensor_model_parallel_size=inference_tp_size,
         materialize_only_last_token_logits=True,
         mamba_inference_state_config=mamba_inference_state_config,
         cache_mla_latent=args.multi_latent_attention and args.cache_mla_latents,
@@ -143,7 +158,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
     inference_wrapped_model = GPTInferenceWrapper(model, args, inference_context)
 
     inference_wrapped_model.model_is_pipeline_parallel = not (
-        parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
+        is_pp_first_stage(pg_collection.pp) and is_pp_last_stage(pg_collection.pp)
     )
 
     text_generation_controller = SimpleTextGenerationController(
@@ -156,6 +171,7 @@ def get_dynamic_inference_engine(args: Namespace, model: MegatronModule, inferen
         enable_cuda_graph=enable_cuda_graph,
         random_seed=args.seed,
         inference_logging_step_interval=inference_logging_step_interval,
+        pg_collection=pg_collection,
     )
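
Pulled out as a standalone function for readability, the TP-size resolution added above looks roughly like this (a sketch under the assumption, implied by the diff, that the wrapped model exposes a pg_collection attribute):

from megatron.core.utils import get_attr_wrapped_model, get_pg_size

def resolve_inference_tp_size(model, args):
    # Prefer the TP size of the model's own process-group collection; fall back
    # to the training-time argument only when no collection or TP group exists.
    pg_collection = get_attr_wrapped_model(model, "pg_collection")
    tp_group = getattr(pg_collection, "tp", None) if pg_collection is not None else None
    if tp_group is not None:
        return get_pg_size(tp_group)
    return args.tensor_model_parallel_size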

megatron/training/training.py

Lines changed: 38 additions & 12 deletions
@@ -52,9 +52,18 @@
 from megatron.core.utils import (
     check_param_hashes_across_dp_replicas,
     get_model_config,
+    get_pg_size,
+    get_pg_rank,
     StragglerDetector,
 )
 from megatron.core.fp8_utils import correct_amax_history_if_needed
+from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.pipeline_parallel.utils import (
+    is_pp_first_stage,
+    is_pp_last_stage,
+    is_vp_first_stage,
+    is_vp_last_stage,
+)
 from megatron.training.checkpointing import load_checkpoint
 from megatron.training.checkpointing import save_checkpoint
 from megatron.training.checkpointing import checkpoint_exists
@@ -886,10 +895,12 @@ def update_train_iters(args):
     print_rank_0(f'setting training iterations to {args.train_iters}')
 
 
-def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
+def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True, config=None, pg_collection=None):
     """Build the model."""
     args = get_args()
     args.model_type = model_type
+    if pg_collection is None:
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
 
     if has_nvidia_modelopt:
         from megatron.post_training.checkpointing import has_modelopt_state
@@ -906,23 +917,38 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # Build model.
     def build_model():
         if (
-            mpu.get_pipeline_model_parallel_world_size() > 1
+            get_pg_size(pg_collection.pp) > 1
             and args.virtual_pipeline_model_parallel_size is not None
         ):
             model = []
-            for i in range(args.virtual_pipeline_model_parallel_size):
+            vp_size = args.virtual_pipeline_model_parallel_size
+            for i in range(vp_size):
                 # Set pre_process and post_process only after virtual rank is set.
-                pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
-                post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
+                pre_process = is_pp_first_stage(pg_collection.pp) and is_vp_first_stage(
+                    vp_stage=i, vp_size=vp_size
+                )
+                post_process = is_pp_last_stage(pg_collection.pp) and is_vp_last_stage(
+                    vp_stage=i, vp_size=vp_size
+                )
                 this_model = model_provider_func(
-                    pre_process=pre_process, post_process=post_process, vp_stage=i)
+                    pre_process=pre_process,
+                    post_process=post_process,
+                    vp_stage=i,
+                    config=config,
+                    pg_collection=pg_collection,
+                )
                 this_model.model_type = model_type
                 this_model.vp_stage = i
                 model.append(this_model)
         else:
-            pre_process = mpu.is_pipeline_first_stage()
-            post_process = mpu.is_pipeline_last_stage()
-            model = model_provider_func(pre_process=pre_process, post_process=post_process)
+            pre_process = is_pp_first_stage(pg_collection.pp)
+            post_process = is_pp_last_stage(pg_collection.pp)
+            model = model_provider_func(
+                pre_process=pre_process,
+                post_process=post_process,
+                config=config,
+                pg_collection=pg_collection,
+            )
         model.model_type = model_type
         return model
 
@@ -947,12 +973,12 @@ def build_model():
     num_parameters = sum(
         [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
     )
-    if mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0:
+    if get_pg_rank(pg_collection.dp) == 0 and get_pg_rank(pg_collection.cp) == 0:
        print(
            ' > number of parameters on (tensor, pipeline) '
            'model parallel rank ({}, {}): {}'.format(
-                mpu.get_tensor_model_parallel_rank(),
-                mpu.get_pipeline_model_parallel_rank(),
+                get_pg_rank(pg_collection.tp),
+                get_pg_rank(pg_collection.pp),
                num_parameters,
            ),
            flush=True,
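
The virtual-pipeline change above is easiest to read as a truth table. A self-contained sketch of the equivalent pre/post-process selection (the real code calls is_pp_first_stage, is_pp_last_stage, is_vp_first_stage, and is_vp_last_stage from megatron.core.pipeline_parallel.utils; the inlined checks below are an assumed simplification):

def stage_flags(pp_rank, pp_size, vp_stage, vp_size):
    # pre_process: only the first pipeline rank, and only its first virtual stage.
    pre_process = (pp_rank == 0) and (vp_stage == 0)
    # post_process: only the last pipeline rank, and only its last virtual stage.
    post_process = (pp_rank == pp_size - 1) and (vp_stage == vp_size - 1)
    return pre_process, post_process

# Example: 4 pipeline ranks, 2 virtual stages per rank.
assert stage_flags(0, 4, 0, 2) == (True, False)   # first rank, first virtual stage
assert stage_flags(3, 4, 1, 2) == (False, True)   # last rank, last virtual stage
assert stage_flags(1, 4, 0, 2) == (False, False)  # interior rank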

model_provider.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 
 
 def model_provider(
-    model_builder: Callable, pre_process=True, post_process=True, vp_stage: Optional[int] = None
+    model_builder: Callable, pre_process=True, post_process=True, vp_stage: Optional[int] = None, config=None, pg_collection=None,
 ) -> Union[GPTModel, megatron.legacy.model.GPTModel, MambaModel]:
     """Builds the model.
 
@@ -64,7 +64,7 @@ def oom_observer(device, alloc, device_alloc, device_free):
         # [ModelOpt]: Use custom builder + spec when modelopt is enabled
         model_builder = modelopt_gpt_mamba_builder
 
-    return model_builder(args, pre_process, post_process, vp_stage)
+    return model_builder(args, pre_process, post_process, vp_stage, config=config, pg_collection=pg_collection)
 
 
 def count_parameters_in_layer(model, layer_name):
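
End to end, the new keywords are expected to be threaded through like this (a hedged sketch of the wiring, assuming get_model, model_provider, and gpt_builder are importable from the modules shown in this commit; the standalone collection below is purely illustrative):

from functools import partial

from megatron.core.process_groups_config import ProcessGroupCollection

# Default collection built from the global mpu state, matching what get_model
# does internally when no collection is passed.
pg_collection = ProcessGroupCollection.use_mpu_process_groups()

# get_model forwards config/pg_collection into the provider, which forwards
# them into the builder, which passes pg_collection on to the model class.
model = get_model(
    partial(model_provider, gpt_builder),
    pg_collection=pg_collection,
)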

pretrain_bert.py

Lines changed: 3 additions & 2 deletions
@@ -28,13 +28,14 @@
 from megatron.core.tokenizers import MegatronTokenizer
 
 
-def model_provider(pre_process=True, post_process=True, vp_stage=None):
+def model_provider(pre_process=True, post_process=True, vp_stage=None, config=None, pg_collection=None):
     """Build the model."""
 
     print_rank_0('building BERT model ...')
 
     args = get_args()
-    config = core_transformer_config_from_args(args)
+    if config is None:
+        config = core_transformer_config_from_args(args)
     num_tokentypes = 2 if args.bert_binary_head else 0
 
     if args.use_legacy_models:

pretrain_t5.py

Lines changed: 8 additions & 2 deletions
@@ -66,7 +66,12 @@
 
 
 def model_provider(
-    pre_process=True, post_process=True, add_encoder=True, add_decoder=True
+    pre_process=True,
+    post_process=True,
+    add_encoder=True,
+    add_decoder=True,
+    config=None,
+    pg_collection=None,
 ) -> Union[megatron.legacy.model.T5Model, T5Model]:
     """Builds the model.
 
@@ -83,7 +88,8 @@ def model_provider(
 
     args = get_args()
 
-    config = core_transformer_config_from_args(args)
+    if config is None:
+        config = core_transformer_config_from_args(args)
     if args.use_legacy_models:
         model = megatron.legacy.model.T5Model(
             config=config,

pretrain_vlm.py

Lines changed: 11 additions & 2 deletions
@@ -43,7 +43,13 @@
 
 
 def model_provider(
-    pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
+    pre_process=True,
+    post_process=True,
+    add_encoder=True,
+    add_decoder=True,
+    parallel_output=True,
+    config=None,
+    pg_collection=None,
 ) -> LLaVAModel:
     """Builds the model.
 
@@ -100,7 +106,10 @@ def model_provider(
     args.max_position_embeddings = max(args.max_position_embeddings, args.decoder_seq_length)
 
     print_rank_0('building a multimodal model ...')
-    language_transformer_config = core_transformer_config_from_args(get_args())
+    if config is None:
+        language_transformer_config = core_transformer_config_from_args(get_args())
+    else:
+        language_transformer_config = config
     if args.decoder_num_layers is not None:
         language_transformer_config.num_layers = args.decoder_num_layers
     else:
