File tree Expand file tree Collapse file tree 2 files changed +14
-10
lines changed
python/paddle/distributed/auto_parallel/static Expand file tree Collapse file tree 2 files changed +14
-10
lines changed Original file line number Diff line number Diff line change 42
42
get_pp_stage_by_pp_degree ,
43
43
get_pp_stage_by_process_mesh ,
44
44
get_sub_process_mesh_by_program ,
45
+ partition_skip_op_list ,
45
46
)
46
47
47
48
_logger = get_logger (
50
51
51
52
register_reshard_funcs ()
52
53
53
- partition_skip_op_list = [
54
- "builtin.combine" ,
55
- "builtin.split" ,
56
- "pd_op.pylayer" ,
57
- "cf.yield" ,
58
- "cf.tuple_push" ,
59
- "cf.tuple_pop" ,
60
- "cf.stack_create" ,
61
- ]
62
-
63
54
amp_ops = ["pd_op.check_finite_and_unscale_" , "pd_op.update_loss_scaling_" ]
64
55
65
56
Original file line number Diff line number Diff line change 56
56
"reduce_sum" ,
57
57
]
58
58
59
+ partition_skip_op_list = [
60
+ "builtin.combine" ,
61
+ "builtin.split" ,
62
+ "pd_op.pylayer" ,
63
+ "cf.yield" ,
64
+ "cf.tuple_push" ,
65
+ "cf.tuple_pop" ,
66
+ "cf.stack_create" ,
67
+ ]
68
+
59
69
60
70
def get_logger (log_level , name = "auto_parallel" ):
61
71
logger = logging .getLogger (name )
@@ -1097,6 +1107,9 @@ def _complete_op_dist_attr(program, block=None):
1097
1107
for op in block .ops :
1098
1108
for sub_block in op .blocks ():
1099
1109
_complete_op_dist_attr (program , block = sub_block )
1110
+ if op .name () in partition_skip_op_list :
1111
+ continue
1112
+
1100
1113
if op .dist_attr is None :
1101
1114
meshes = []
1102
1115
operand_attrs = []
You can’t perform that action at this time.
0 commit comments