Skip to content

Commit 91108c7

Browse files
authored
update (#69584)
1 parent 0aaaa8d commit 91108c7

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

python/paddle/distributed/auto_parallel/static/pir_pass.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
get_pp_stage_by_pp_degree,
4343
get_pp_stage_by_process_mesh,
4444
get_sub_process_mesh_by_program,
45+
partition_skip_op_list,
4546
)
4647

4748
_logger = get_logger(
@@ -50,16 +51,6 @@
5051

5152
register_reshard_funcs()
5253

53-
partition_skip_op_list = [
54-
"builtin.combine",
55-
"builtin.split",
56-
"pd_op.pylayer",
57-
"cf.yield",
58-
"cf.tuple_push",
59-
"cf.tuple_pop",
60-
"cf.stack_create",
61-
]
62-
6354
amp_ops = ["pd_op.check_finite_and_unscale_", "pd_op.update_loss_scaling_"]
6455

6556

python/paddle/distributed/auto_parallel/static/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@
5656
"reduce_sum",
5757
]
5858

59+
partition_skip_op_list = [
60+
"builtin.combine",
61+
"builtin.split",
62+
"pd_op.pylayer",
63+
"cf.yield",
64+
"cf.tuple_push",
65+
"cf.tuple_pop",
66+
"cf.stack_create",
67+
]
68+
5969

6070
def get_logger(log_level, name="auto_parallel"):
6171
logger = logging.getLogger(name)
@@ -1097,6 +1107,9 @@ def _complete_op_dist_attr(program, block=None):
10971107
for op in block.ops:
10981108
for sub_block in op.blocks():
10991109
_complete_op_dist_attr(program, block=sub_block)
1110+
if op.name() in partition_skip_op_list:
1111+
continue
1112+
11001113
if op.dist_attr is None:
11011114
meshes = []
11021115
operand_attrs = []

0 commit comments

Comments
 (0)