diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_1f1b.py b/python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_1f1b.py index 7d68aa9e6a6b38..2db0188147c9fe 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_1f1b.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_1f1b.py @@ -26,7 +26,6 @@ from ..pass_utils import ( AutoParallelStreamType, _add_event_dependency, - _program_for_fthenb_and_1f1b, forward_complete_op_role, split_program, ) @@ -55,178 +54,13 @@ def __init__(self): ] self.set_attr("enable_backward_forward_overlap", 0) - # Backward-forward overlapping splits and rearranges jobs for pattern Bi-Fj. - # For example: jobs = {..., BACKWARD-i, FORWARD-j, ...}, i < j - # BACKWARD-i: Calc1 - Comm1 - Calc2 - Comm2 - Calc3 - # FORWARD-j: Calc4 - Comm3 - Calc5 - Comm4 - Calc6 - # Timeline: - # ===Calc1==Comm1==Calc2==Comm2==Calc3==Calc4==Comm3==Calc5==Comm4==Calc6=== - # - # After backward-forward overlapping: jobs = {Calc1, Comm1, Calc4, Comm3, Calc2, Comm2, Calc5, Comm4, Calc3, Calc6} - # Timeline: - # ===Calc1==Calc4==Calc2==Calc5==Calc3=Calc6=== - # \ / \ / - # \ / \ / - # ==========Comm1==Comm3==Comm2==Comm4========== - # - def _backward_forward_overlap(self, backward_program, forward_program): - logger.info("Backward forward overlap enabled in 1F1B.") - # Split program - backward_ops, forward_ops = ( - backward_program.global_block().ops, - forward_program.global_block().ops, - ) - num_backward_ops, num_forward_ops = len(backward_ops), len(forward_ops) - backward_split_points, forward_split_points = [], [] - backward_op_id, forward_op_id = 0, 0 - - while ( - backward_op_id < num_backward_ops - and forward_op_id < num_forward_ops - ): - # TODO(Ruibiao): Constrain the number of valid comm ops to resolve the potential memory explosion issue. - while ( - backward_op_id < num_backward_ops - and not self.is_comm_op_valid_to_overlap( - backward_ops[backward_op_id] - ) - ): - backward_op_id += 1 - - if backward_op_id >= num_backward_ops: - break - - backward_op_to_overlap = backward_ops[backward_op_id] - backward_cost_to_overlap = 400 - backward_op_id += 1 - - forward_op_to_overlap = forward_ops[forward_op_id] - forward_cost_to_overlap = self._op_cost(forward_op_to_overlap) - ''' - # Debug messages: - logger.info( - f"backward_op_to_overlap : {backward_op_to_overlap}, cost = {backward_cost_to_overlap}" - ) - logger.info( - f"forward_op_to_overlap : {forward_op_to_overlap}, cost = {forward_cost_to_overlap}" - ) - ''' - - while ( - forward_op_id < num_forward_ops - and backward_cost_to_overlap >= forward_cost_to_overlap - ): - forward_op_id += 1 - forward_op_to_overlap = forward_ops[forward_op_id] - forward_cost_to_overlap += self._op_cost(forward_op_to_overlap) - ''' - # Debug messages: - logger.info( - f"forward_op_to_overlap : {forward_op_to_overlap}, cost = {self._op_cost(forward_op_to_overlap)}" - ) - ''' - - if self.is_comm_op_valid_to_overlap( - forward_ops[forward_op_id - 1] - ): - break - - if ( - not forward_split_points - or forward_op_id > forward_split_points[-1] - ): - backward_split_points.append(backward_op_id) - forward_split_points.append(forward_op_id) - - ( - splitted_backward_job_types, - splitted_backward_programs, - ) = self._split_program_for_overlapping( - BACKWARD, backward_program, backward_split_points - ) - ( - splitted_forward_job_types, - splitted_forward_programs, - ) = self._split_program_for_overlapping( - FORWARD, forward_program, forward_split_points - ) - - self._multistreaming_for_overlapping( - splitted_backward_programs, BACKWARD - ) - self._multistreaming_for_overlapping(splitted_forward_programs, FORWARD) - - # Rearrange splitted chunks for BACKWARD and FORWARD - self.jobs_in_stable_phase.clear() - num_splitted_backward_jobs, num_splitted_forward_jobs = len( - splitted_backward_job_types - ), len(splitted_forward_job_types) - for idx in range( - max(num_splitted_backward_jobs, num_splitted_forward_jobs) - ): - if idx < num_splitted_backward_jobs: - self.jobs_in_stable_phase.append( - splitted_backward_job_types[idx] - ) - if idx < num_splitted_forward_jobs: - self.jobs_in_stable_phase.append( - splitted_forward_job_types[idx] - ) - - return ( - splitted_backward_job_types, - splitted_backward_programs, - splitted_forward_job_types, - splitted_forward_programs, - ) - def _create_job_list(self): if self._in_pir_mode: return self._create_job_list_in_pir() - - num_micro_batches = self.get_attr("num_micro_batches") - pp_stage = self.get_attr("pp_stage") - pp_degree = self.get_attr("pp_degree") - - job_list = [] - assert ( - pp_degree <= num_micro_batches - ), "Num of micro batches should larger than or equal to pp degree." - - micro_batch_in_warmup = pp_degree - pp_stage - micro_batch_in_1f1b = num_micro_batches - micro_batch_in_warmup - - forward_micro_batch_id = 0 - for i in range(micro_batch_in_warmup): - forward_job = core.Job(FORWARD) - forward_job.set_micro_batch_id(forward_micro_batch_id) - job_list.append(forward_job) - forward_micro_batch_id += 1 - - backward_micro_batch_id = 0 - for i in range(micro_batch_in_1f1b): - for job_type in self.jobs_in_stable_phase: - job = core.Job(job_type) - micro_batch_id = ( - forward_micro_batch_id - if job_type.startswith(FORWARD) - else backward_micro_batch_id - ) - job.set_micro_batch_id(micro_batch_id) - job_list.append(job) - forward_micro_batch_id += 1 - backward_micro_batch_id += 1 - - for i in range(micro_batch_in_warmup): - backward_job = core.Job(BACKWARD) - backward_job.set_micro_batch_id(backward_micro_batch_id) - job_list.append(backward_job) - backward_micro_batch_id += 1 - - opt_job = core.Job(OPT) - opt_job.set_micro_batch_id(0) - job_list.append(opt_job) - return job_list + else: + raise NotImplementedError( + "_create_job_list() only support PIR now." + ) def _create_job_list_in_pir(self): num_micro_batches = self.get_attr("num_micro_batches") @@ -373,40 +207,7 @@ def _op_cost(self, op): return 0.0 def _partial_programs(self, program): - # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs. - enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap") - types = [FORWARD, BACKWARD, OPT] - sub_programs = _program_for_fthenb_and_1f1b( - program, enable_send_recv_overlap - ) - - enable_backward_forward_overlap = self.get_attr( - "enable_backward_forward_overlap" - ) - - if enable_backward_forward_overlap: - logger.info("Backward forward overlap enabled in 1F1B.") - forward_program, backward_program = sub_programs[1], sub_programs[2] - ( - splitted_backward_job_types, - splitted_backward_programs, - splitted_forward_job_types, - splitted_forward_programs, - ) = self._backward_forward_overlap( - backward_program, forward_program - ) - types += splitted_forward_job_types + splitted_backward_job_types - sub_programs += ( - splitted_forward_programs + splitted_backward_programs - ) - - for i in range(len(types)): - logger.debug( - f"type = {types[i]}, sub_programs = {sub_programs[i]}\n" - ) - logger.debug(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}") - - return types, sub_programs + raise NotImplementedError("pipeline_1f1b_pass() only support PIR now.") def _partial_pir_programs(self, program): enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap") diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_fthenb.py b/python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_fthenb.py index de42eab3ce3424..c591ccc2e11d28 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_fthenb.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_fthenb.py @@ -19,7 +19,6 @@ from ...utils.log_utils import get_logger from ..pass_base import register_pass from ..pass_utils import ( - _program_for_fthenb_and_1f1b, _split_program_into_forward_backward_optimize, ) from .pipeline_pass_base import PipelinePassBase @@ -57,13 +56,9 @@ def _create_job_list(self): return job_list def _partial_programs(self, program): - # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs. - enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap") - types = [FORWARD, BACKWARD, OPT] - sub_program_list = _program_for_fthenb_and_1f1b( - program, enable_send_recv_overlap + raise NotImplementedError( + "pipeline_fthenb_pass() only support PIR now." ) - return types, sub_program_list def _partial_pir_programs(self, program): # NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs. diff --git a/test/standalone_executor/test_standalone_executor_1f1b_plan.py b/test/standalone_executor/test_standalone_executor_1f1b_plan.py deleted file mode 100644 index e40facbefe179d..00000000000000 --- a/test/standalone_executor/test_standalone_executor_1f1b_plan.py +++ /dev/null @@ -1,260 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from paddle import base, static -from paddle.distributed.passes import PassContext, new_pass - - -class TestStandaloneExecutor1F1BPlan(unittest.TestCase): - def test_standalone_executor_1f1b_plan_stage0(self): - base.set_flags({'FLAGS_enable_pir_api': 0}) - config = {"num_micro_batches": 8, "pp_stage": 0, "pp_degree": 4} - pass_context = PassContext() - - startup_program = static.Program() - main_program = static.Program() - - pipeline_1f1b_pass = new_pass("pipeline_scheduler_1F1B", config) - pipeline_1f1b_pass.apply( - [main_program], [startup_program], pass_context - ) - plan = pass_context.get_attr("plan") - job_type_list = [] - micro_batch_id_list = [] - for job in plan.job_list(): - job_type_list.append(job.type()) - micro_batch_id_list.append(job.micro_batch_id()) - expect_job_type_list = [ - "forward", - "forward", - "forward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "backward", - "backward", - "backward", - "optimizer", - ] - expect_micro_batch_id_list = [ - 0, - 1, - 2, - 3, - 0, - 4, - 1, - 5, - 2, - 6, - 3, - 7, - 4, - 5, - 6, - 7, - 0, - ] - self.assertEqual(job_type_list, expect_job_type_list) - self.assertEqual(micro_batch_id_list, expect_micro_batch_id_list) - - def test_standalone_executor_1f1b_plan_stage1(self): - base.set_flags({'FLAGS_enable_pir_api': 0}) - config = {"num_micro_batches": 8, "pp_stage": 1, "pp_degree": 4} - pass_context = PassContext() - - startup_program = static.Program() - main_program = static.Program() - - pipeline_1f1b_pass = new_pass("pipeline_scheduler_1F1B", config) - pipeline_1f1b_pass.apply( - [main_program], [startup_program], pass_context - ) - plan = pass_context.get_attr("plan") - job_type_list = [] - micro_batch_id_list = [] - for job in plan.job_list(): - job_type_list.append(job.type()) - micro_batch_id_list.append(job.micro_batch_id()) - expect_job_type_list = [ - "forward", - "forward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "backward", - "backward", - "optimizer", - ] - expect_micro_batch_id_list = [ - 0, - 1, - 2, - 0, - 3, - 1, - 4, - 2, - 5, - 3, - 6, - 4, - 7, - 5, - 6, - 7, - 0, - ] - self.assertEqual(job_type_list, expect_job_type_list) - self.assertEqual(micro_batch_id_list, expect_micro_batch_id_list) - - def test_standalone_executor_1f1b_plan_stage2(self): - base.set_flags({'FLAGS_enable_pir_api': 0}) - config = {"num_micro_batches": 8, "pp_stage": 2, "pp_degree": 4} - pass_context = PassContext() - - startup_program = static.Program() - main_program = static.Program() - - pipeline_1f1b_pass = new_pass("pipeline_scheduler_1F1B", config) - pipeline_1f1b_pass.apply( - [main_program], [startup_program], pass_context - ) - plan = pass_context.get_attr("plan") - job_type_list = [] - micro_batch_id_list = [] - for job in plan.job_list(): - job_type_list.append(job.type()) - micro_batch_id_list.append(job.micro_batch_id()) - expect_job_type_list = [ - "forward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "backward", - "optimizer", - ] - expect_micro_batch_id_list = [ - 0, - 1, - 0, - 2, - 1, - 3, - 2, - 4, - 3, - 5, - 4, - 6, - 5, - 7, - 6, - 7, - 0, - ] - self.assertEqual(job_type_list, expect_job_type_list) - self.assertEqual(micro_batch_id_list, expect_micro_batch_id_list) - - def test_standalone_executor_1f1b_plan_stage3(self): - base.set_flags({'FLAGS_enable_pir_api': 0}) - config = {"num_micro_batches": 8, "pp_stage": 3, "pp_degree": 4} - pass_context = PassContext() - - startup_program = static.Program() - main_program = static.Program() - - pipeline_1f1b_pass = new_pass("pipeline_scheduler_1F1B", config) - pipeline_1f1b_pass.apply( - [main_program], [startup_program], pass_context - ) - plan = pass_context.get_attr("plan") - job_type_list = [] - micro_batch_id_list = [] - for job in plan.job_list(): - job_type_list.append(job.type()) - micro_batch_id_list.append(job.micro_batch_id()) - expect_job_type_list = [ - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "forward", - "backward", - "optimizer", - ] - expect_micro_batch_id_list = [ - 0, - 0, - 1, - 1, - 2, - 2, - 3, - 3, - 4, - 4, - 5, - 5, - 6, - 6, - 7, - 7, - 0, - ] - self.assertEqual(job_type_list, expect_job_type_list) - self.assertEqual(micro_batch_id_list, expect_micro_batch_id_list) - - -if __name__ == '__main__': - unittest.main()