Skip to content

【Hackathon 8th No.3】 clean oldIR for pipeline scheduler -part #71246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from ..pass_utils import (
AutoParallelStreamType,
_add_event_dependency,
_program_for_fthenb_and_1f1b,
forward_complete_op_role,
split_program,
)
Expand Down Expand Up @@ -55,178 +54,13 @@ def __init__(self):
]
self.set_attr("enable_backward_forward_overlap", 0)

# Backward-forward overlapping splits and rearranges jobs for pattern Bi-Fj.
# For example: jobs = {..., BACKWARD-i, FORWARD-j, ...}, i < j
# BACKWARD-i: Calc1 - Comm1 - Calc2 - Comm2 - Calc3
# FORWARD-j: Calc4 - Comm3 - Calc5 - Comm4 - Calc6
# Timeline:
# ===Calc1==Comm1==Calc2==Comm2==Calc3==Calc4==Comm3==Calc5==Comm4==Calc6===
#
# After backward-forward overlapping: jobs = {Calc1, Comm1, Calc4, Comm3, Calc2, Comm2, Calc5, Comm4, Calc3, Calc6}
# Timeline:
# ===Calc1==Calc4==Calc2==Calc5==Calc3=Calc6===
# \ / \ /
# \ / \ /
# ==========Comm1==Comm3==Comm2==Comm4==========
#
def _backward_forward_overlap(self, backward_program, forward_program):
logger.info("Backward forward overlap enabled in 1F1B.")
# Split program
backward_ops, forward_ops = (
backward_program.global_block().ops,
forward_program.global_block().ops,
)
num_backward_ops, num_forward_ops = len(backward_ops), len(forward_ops)
backward_split_points, forward_split_points = [], []
backward_op_id, forward_op_id = 0, 0

while (
backward_op_id < num_backward_ops
and forward_op_id < num_forward_ops
):
# TODO(Ruibiao): Constrain the number of valid comm ops to resolve the potential memory explosion issue.
while (
backward_op_id < num_backward_ops
and not self.is_comm_op_valid_to_overlap(
backward_ops[backward_op_id]
)
):
backward_op_id += 1

if backward_op_id >= num_backward_ops:
break

backward_op_to_overlap = backward_ops[backward_op_id]
backward_cost_to_overlap = 400
backward_op_id += 1

forward_op_to_overlap = forward_ops[forward_op_id]
forward_cost_to_overlap = self._op_cost(forward_op_to_overlap)
'''
# Debug messages:
logger.info(
f"backward_op_to_overlap : {backward_op_to_overlap}, cost = {backward_cost_to_overlap}"
)
logger.info(
f"forward_op_to_overlap : {forward_op_to_overlap}, cost = {forward_cost_to_overlap}"
)
'''

while (
forward_op_id < num_forward_ops
and backward_cost_to_overlap >= forward_cost_to_overlap
):
forward_op_id += 1
forward_op_to_overlap = forward_ops[forward_op_id]
forward_cost_to_overlap += self._op_cost(forward_op_to_overlap)
'''
# Debug messages:
logger.info(
f"forward_op_to_overlap : {forward_op_to_overlap}, cost = {self._op_cost(forward_op_to_overlap)}"
)
'''

if self.is_comm_op_valid_to_overlap(
forward_ops[forward_op_id - 1]
):
break

if (
not forward_split_points
or forward_op_id > forward_split_points[-1]
):
backward_split_points.append(backward_op_id)
forward_split_points.append(forward_op_id)

(
splitted_backward_job_types,
splitted_backward_programs,
) = self._split_program_for_overlapping(
BACKWARD, backward_program, backward_split_points
)
(
splitted_forward_job_types,
splitted_forward_programs,
) = self._split_program_for_overlapping(
FORWARD, forward_program, forward_split_points
)

self._multistreaming_for_overlapping(
splitted_backward_programs, BACKWARD
)
self._multistreaming_for_overlapping(splitted_forward_programs, FORWARD)

# Rearrange splitted chunks for BACKWARD and FORWARD
self.jobs_in_stable_phase.clear()
num_splitted_backward_jobs, num_splitted_forward_jobs = len(
splitted_backward_job_types
), len(splitted_forward_job_types)
for idx in range(
max(num_splitted_backward_jobs, num_splitted_forward_jobs)
):
if idx < num_splitted_backward_jobs:
self.jobs_in_stable_phase.append(
splitted_backward_job_types[idx]
)
if idx < num_splitted_forward_jobs:
self.jobs_in_stable_phase.append(
splitted_forward_job_types[idx]
)

return (
splitted_backward_job_types,
splitted_backward_programs,
splitted_forward_job_types,
splitted_forward_programs,
)

def _create_job_list(self):
if self._in_pir_mode:
return self._create_job_list_in_pir()

num_micro_batches = self.get_attr("num_micro_batches")
pp_stage = self.get_attr("pp_stage")
pp_degree = self.get_attr("pp_degree")

job_list = []
assert (
pp_degree <= num_micro_batches
), "Num of micro batches should larger than or equal to pp degree."

micro_batch_in_warmup = pp_degree - pp_stage
micro_batch_in_1f1b = num_micro_batches - micro_batch_in_warmup

forward_micro_batch_id = 0
for i in range(micro_batch_in_warmup):
forward_job = core.Job(FORWARD)
forward_job.set_micro_batch_id(forward_micro_batch_id)
job_list.append(forward_job)
forward_micro_batch_id += 1

backward_micro_batch_id = 0
for i in range(micro_batch_in_1f1b):
for job_type in self.jobs_in_stable_phase:
job = core.Job(job_type)
micro_batch_id = (
forward_micro_batch_id
if job_type.startswith(FORWARD)
else backward_micro_batch_id
)
job.set_micro_batch_id(micro_batch_id)
job_list.append(job)
forward_micro_batch_id += 1
backward_micro_batch_id += 1

for i in range(micro_batch_in_warmup):
backward_job = core.Job(BACKWARD)
backward_job.set_micro_batch_id(backward_micro_batch_id)
job_list.append(backward_job)
backward_micro_batch_id += 1

opt_job = core.Job(OPT)
opt_job.set_micro_batch_id(0)
job_list.append(opt_job)
return job_list
else:
raise NotImplementedError(
"_create_job_list() only support PIR now."
)

def _create_job_list_in_pir(self):
num_micro_batches = self.get_attr("num_micro_batches")
Expand Down Expand Up @@ -373,40 +207,7 @@ def _op_cost(self, op):
return 0.0

def _partial_programs(self, program):
# NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs.
enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
types = [FORWARD, BACKWARD, OPT]
sub_programs = _program_for_fthenb_and_1f1b(
program, enable_send_recv_overlap
)

enable_backward_forward_overlap = self.get_attr(
"enable_backward_forward_overlap"
)

if enable_backward_forward_overlap:
logger.info("Backward forward overlap enabled in 1F1B.")
forward_program, backward_program = sub_programs[1], sub_programs[2]
(
splitted_backward_job_types,
splitted_backward_programs,
splitted_forward_job_types,
splitted_forward_programs,
) = self._backward_forward_overlap(
backward_program, forward_program
)
types += splitted_forward_job_types + splitted_backward_job_types
sub_programs += (
splitted_forward_programs + splitted_backward_programs
)

for i in range(len(types)):
logger.debug(
f"type = {types[i]}, sub_programs = {sub_programs[i]}\n"
)
logger.debug(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}")

return types, sub_programs
raise NotImplementedError("pipeline_1f1b_pass() only support PIR now.")

def _partial_pir_programs(self, program):
enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ...utils.log_utils import get_logger
from ..pass_base import register_pass
from ..pass_utils import (
_program_for_fthenb_and_1f1b,
_split_program_into_forward_backward_optimize,
)
from .pipeline_pass_base import PipelinePassBase
Expand Down Expand Up @@ -57,13 +56,9 @@ def _create_job_list(self):
return job_list

def _partial_programs(self, program):
# NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs.
enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
types = [FORWARD, BACKWARD, OPT]
sub_program_list = _program_for_fthenb_and_1f1b(
program, enable_send_recv_overlap
raise NotImplementedError(
"pipeline_fthenb_pass() only support PIR now."
)
return types, sub_program_list

def _partial_pir_programs(self, program):
# NOTE: The flag "enable_send_recv_overlap" may increase the reserved memory of GPUs.
Expand Down
Loading