 from megatron.core.utils import (
     check_param_hashes_across_dp_replicas,
     get_model_config,
+    get_pg_size,
+    get_pg_rank,
     StragglerDetector,
 )
 from megatron.core.fp8_utils import correct_amax_history_if_needed
+from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.pipeline_parallel.utils import (
+    is_pp_first_stage,
+    is_pp_last_stage,
+    is_vp_first_stage,
+    is_vp_last_stage,
+)
 from megatron.training.checkpointing import load_checkpoint
 from megatron.training.checkpointing import save_checkpoint
 from megatron.training.checkpointing import checkpoint_exists
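The newly imported `get_pg_size` / `get_pg_rank` helpers query an explicit `torch.distributed` process group instead of the global `mpu` state. A minimal conceptual sketch of what they reduce to, assuming a missing group means "not parallelized along that axis" (the `sketch_*` names are illustrative, not the library implementation):

```python
# Conceptual sketch only, not megatron.core.utils code.
import torch.distributed as dist

def sketch_get_pg_size(pg=None):
    # World size of the given process group; 1 if there is no group.
    if pg is None or not dist.is_initialized():
        return 1
    return dist.get_world_size(group=pg)

def sketch_get_pg_rank(pg=None):
    # Rank of this process within the given group; 0 if there is no group.
    if pg is None or not dist.is_initialized():
        return 0
    return dist.get_rank(group=pg)
```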
@@ -886,10 +895,12 @@ def update_train_iters(args):
     print_rank_0(f'setting training iterations to {args.train_iters}')


-def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
+def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True, config=None, pg_collection=None):
     """Build the model."""
     args = get_args()
     args.model_type = model_type
+    if pg_collection is None:
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()

     if has_nvidia_modelopt:
         from megatron.post_training.checkpointing import has_modelopt_state
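With the widened signature, callers can either rely on the default (process groups derived from the global `mpu` state) or pass in their own collection. A hedged usage sketch, where `my_model_provider` is a placeholder provider function and megatron's distributed/model-parallel state is assumed to be initialized:

```python
# Hypothetical usage of the patched get_model() above; my_model_provider is a
# placeholder, and parallel state must already be initialized.
from megatron.core.process_groups_config import ProcessGroupCollection

# 1) Default path: get_model() falls back internally to
#    ProcessGroupCollection.use_mpu_process_groups().
model = get_model(my_model_provider, wrap_with_ddp=True)

# 2) Explicit path: the caller owns the process groups and passes them in,
#    e.g. when groups are built outside the global mpu state.
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
model = get_model(my_model_provider, wrap_with_ddp=True, pg_collection=pg_collection)
```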
@@ -906,23 +917,38 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # Build model.
     def build_model():
         if (
-            mpu.get_pipeline_model_parallel_world_size() > 1
+            get_pg_size(pg_collection.pp) > 1
             and args.virtual_pipeline_model_parallel_size is not None
         ):
             model = []
-            for i in range(args.virtual_pipeline_model_parallel_size):
+            vp_size = args.virtual_pipeline_model_parallel_size
+            for i in range(vp_size):
                 # Set pre_process and post_process only after virtual rank is set.
-                pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
-                post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
+                pre_process = is_pp_first_stage(pg_collection.pp) and is_vp_first_stage(
+                    vp_stage=i, vp_size=vp_size
+                )
+                post_process = is_pp_last_stage(pg_collection.pp) and is_vp_last_stage(
+                    vp_stage=i, vp_size=vp_size
+                )
                 this_model = model_provider_func(
-                    pre_process=pre_process, post_process=post_process, vp_stage=i)
+                    pre_process=pre_process,
+                    post_process=post_process,
+                    vp_stage=i,
+                    config=config,
+                    pg_collection=pg_collection,
+                )
                 this_model.model_type = model_type
                 this_model.vp_stage = i
                 model.append(this_model)
         else:
-            pre_process = mpu.is_pipeline_first_stage()
-            post_process = mpu.is_pipeline_last_stage()
-            model = model_provider_func(pre_process=pre_process, post_process=post_process)
+            pre_process = is_pp_first_stage(pg_collection.pp)
+            post_process = is_pp_last_stage(pg_collection.pp)
+            model = model_provider_func(
+                pre_process=pre_process,
+                post_process=post_process,
+                config=config,
+                pg_collection=pg_collection,
+            )
             model.model_type = model_type
         return model

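The refactored branch splits the old combined `mpu.is_pipeline_first/last_stage(..., vp_stage=i)` check into a real pipeline-rank check (`is_pp_first/last_stage` on `pg_collection.pp`) and a virtual-stage check (`is_vp_first/last_stage`). A standalone sketch of the resulting pre/post-process decision, assuming the virtual-stage predicates reduce to simple index comparisons (the `sketch_*` names are illustrative, not library code):

```python
# Conceptual sketch of the pre/post-process flags in the virtual-pipeline
# branch above; assumes is_vp_first/last_stage are plain index checks.
def sketch_is_vp_first_stage(vp_stage, vp_size):
    return vp_stage == 0

def sketch_is_vp_last_stage(vp_stage, vp_size):
    return vp_stage == vp_size - 1

def sketch_stage_flags(pp_rank, pp_size, vp_size):
    """Yield (vp_stage, pre_process, post_process) for each model chunk."""
    for i in range(vp_size):
        # The embedding chunk lives only on the first real pipeline rank's
        # first virtual stage; the output head only on the last rank's last
        # virtual stage.
        pre_process = (pp_rank == 0) and sketch_is_vp_first_stage(i, vp_size)
        post_process = (pp_rank == pp_size - 1) and sketch_is_vp_last_stage(i, vp_size)
        yield i, pre_process, post_process

# Example: pipeline rank 0 of 2, with 2 virtual stages per rank.
print(list(sketch_stage_flags(pp_rank=0, pp_size=2, vp_size=2)))
# -> [(0, True, False), (1, False, False)]
```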
@@ -947,12 +973,12 @@ def build_model():
     num_parameters = sum(
         [sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
     )
-    if mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0:
+    if get_pg_rank(pg_collection.dp) == 0 and get_pg_rank(pg_collection.cp) == 0:
         print(
             ' > number of parameters on (tensor, pipeline) '
             'model parallel rank ({}, {}): {}'.format(
-                mpu.get_tensor_model_parallel_rank(),
-                mpu.get_pipeline_model_parallel_rank(),
+                get_pg_rank(pg_collection.tp),
+                get_pg_rank(pg_collection.pp),
                 num_parameters,
             ),
             flush=True,