【Hackathon 8th No.3】 clean oldIR in engine.py -part #71245

Closed · wants to merge 3 commits
251 changes: 40 additions & 211 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -19,7 +19,6 @@
import logging
import numbers
import os
import random
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
@@ -44,10 +43,9 @@
)
from paddle.metric import Metric
from paddle.static import InputSpec, Operator, Variable, global_scope
from paddle.static.amp.fp16_utils import _convert_float_to_bfloat16

from ...utils.log_utils import get_logger
from ..interface import CollectionNames, fetch, get_collection
from ..interface import CollectionNames, get_collection
from ..static.dist_tensor import DistributedTensor
from ..strategy import Strategy
from .callbacks import config_callbacks
@@ -75,7 +73,7 @@
remove_unuseful_comm_op_pass,
)
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group
from .process_group import get_all_process_groups
from .utils import set_all_ops_op_role

if TYPE_CHECKING:
@@ -576,37 +574,8 @@ def _prepare_fetch(self, user_fetches, mode):
# TODO(2024-Q2)
if self._in_pir_mode:
return fetch_names, fetch_indices

def _process_fetch_group(group_name, var_list):
group_indices = []
for var in var_list:
# Remove duplicate var_names
if self._is_local_var(var):
var_name = _to_name_str(var)
if var_name not in fetch_names:
fetch_names.append(var_name)
group_indices.append(fetch_names.index(var_name))
fetch_indices.append(group_indices)

dist_context = self._dist_contexts[mode]
fetch_vars = dist_context.serial_fetch_vars
if mode != "predict":
_process_fetch_group("loss", fetch_vars["loss"])
if mode != "predict":
metrics = fetch_vars["metrics"]
for i, var_list in enumerate(metrics):
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches or []:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES)
]
var_list = user_fetches_collection or []
_process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices
else:
raise NotImplementedError("_prepare_fetch() only support PIR now.")

def _prepare_logger(
self,
@@ -1011,37 +980,23 @@ def _parallel_pir(self, mode):
self._pir_dist_startup_progs[mode] = startup_program

def _prepare_program(self, mode, init_parameters=True):
if self._in_pir_mode:
with paddle.amp.auto_cast(
enable=self._strategy.amp.enable,
custom_white_list=self._strategy.amp.custom_white_list,
custom_black_list=self._strategy.amp.custom_black_list,
level=self._strategy.amp.level,
dtype=self._strategy.amp.dtype,
use_promote=self._strategy.amp.use_promote,
):
self._build(mode)
self._parallel_pir(mode)
# Init comm
self._init_comm()
# startup program
self._initialize(mode, init_parameters)
self._has_prepared[mode] = True
return

# legacy program
# Do the build process
self._build(mode)
# Do the planning process
self._plan(mode)
# Do the parallel process
self._parallel(mode)
if not self._in_pir_mode:
raise NotImplementedError("prepare_program() only support PIR now.")

with paddle.amp.auto_cast(
enable=self._strategy.amp.enable,
custom_white_list=self._strategy.amp.custom_white_list,
custom_black_list=self._strategy.amp.custom_black_list,
level=self._strategy.amp.level,
dtype=self._strategy.amp.dtype,
use_promote=self._strategy.amp.use_promote,
):
self._build(mode)
self._parallel_pir(mode)
# Init comm
self._init_comm()
# startup program
self._initialize(mode, init_parameters)
# mark main program for further decompose
self._mark_prim(mode)
self._has_prepared[mode] = True

def _process_dist_input_specs(self):
@@ -1161,61 +1116,8 @@ def _build(self, mode):
self._fwd_main_progs[mode] = serial_main_prog
self._startup_progs[mode] = serial_startup_prog
return

default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation:
# We build the world process group because the data parallel
# needs all ranks by default.
new_process_group(list(range(self._nranks)))
default_ctx.data_parallel = True
self._inputs = [
auto_utils.set_data_parallel(var) for var in self._inputs
]
self._labels = [
auto_utils.set_data_parallel(var) for var in self._labels
]

feed_vars = {"inputs": self._inputs, "labels": self._labels}

fetch_vars = {
"outputs": paddle.utils.flatten(outputs),
"loss": self._losses,
"metrics": metrics,
}

if mode != "train":
serial_main_prog = serial_main_prog.clone(for_test=True)

auto_utils.set_recompute_segments(
self._model, self._losses, self._strategy, serial_main_prog
)
self._dist_contexts[mode] = DistributedContext(
serial_main_prog,
serial_startup_prog,
self._optimizer,
self._losses,
feed_vars,
fetch_vars,
self._cluster,
self._strategy,
self._json_config,
)
self._fwd_dist_contexts[mode] = DistributedContext(
serial_main_prog,
serial_startup_prog,
self._optimizer,
self._losses,
feed_vars,
fetch_vars,
self._cluster,
self._strategy,
self._json_config,
)
self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
self._dist_contexts[mode].gradient_scale_using_allreduce_avg = (
self._strategy.gradient_scale_using_allreduce_avg
)
self._fwd_main_progs[mode] = serial_main_prog.clone()
else:
raise NotImplementedError("_build() only support PIR now.")

def _optimization_tuning(self, mode, dataset, batch_size):
if not self._tuning.enable:
@@ -1325,18 +1227,8 @@ def _init_comm(self):
for process_group in all_process_groups:
process_group.instantiate()
return

# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()

if self._strategy.auto_mode == "full_random":
auto_utils.initialize_pg_in_full_mode(
all_process_groups, self._cur_rank
)
else:
for process_group in all_process_groups:
process_group.instantiate()
raise NotImplementedError("_init_comm() only support PIR now.")

def _init_lr(self, main_program):
# hack to find learning_rate op
@@ -1456,83 +1348,8 @@ def _initialize(self, mode, init_parameters=True):
self._executor._set_plan(self._job_plan)
return

if self._strategy.seed:
paddle.seed(self._strategy.seed + self._dp_ranks[0])
np.random.seed(self._strategy.seed + self._dp_ranks[0])
random.seed(self._strategy.seed + self._dp_ranks[0])

dist_context = self._dist_contexts[mode]
dist_main_program = dist_context.dist_main_programs[self._cur_rank]
if self._dygraph_mode:
self.program_helper.init(
dist_main_program, self._place, dist_context
)
# The model's instance variables (not parameters), used in forward function,
# have been initialized when initialize model in dynamic mode.
if self._model and len(self._model.buffers()) > 0:
for buffer in self._model.buffers():
if dist_main_program.global_block().has_var(buffer.name):
dest_type = (
dist_main_program.global_block()
.var(buffer.name)
.dtype
)
scope_var = global_scope().find_var(buffer.name)
buffer_tensor = (
global_scope().var(buffer.name).get_tensor()
)
if scope_var and buffer_tensor._is_initialized():
continue
# for amp
if dest_type == paddle.bfloat16:
buffer_tensor.set(
_convert_float_to_bfloat16(
self._place, buffer.numpy()
),
self._place,
)
elif dest_type == paddle.float16:
buffer_tensor.set(
np.float16(buffer.numpy()), self._place
)
else:
buffer_tensor.set(buffer.numpy(), self._place)

if self._executor is None:
self._executor = paddle.static.Executor(self._place)
uninitialized = []
dist_startup_prog = dist_context.dist_startup_programs[
self._cur_rank
]
for var in dist_startup_prog.list_vars():
scope_var = global_scope().find_var(var.name)
if scope_var and scope_var.get_tensor()._is_initialized():
continue
uninitialized.append(var)
# Make sure the number of communication operators is consistent
commu_ops = []
if self._nranks > 1:
for op in dist_startup_prog.global_block().ops:
if auto_utils.is_comm_op(op):
commu_ops.append(op)
reserved_vars_and_ops = uninitialized + commu_ops
if reserved_vars_and_ops:
prune_startup_prog = dist_startup_prog._prune(
reserved_vars_and_ops
)
self._executor.run(prune_startup_prog)

if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
self._set_state_dict(
mode, self._strict, self._state_dict, self._dist_attr
)

if self._strategy.reinit:
self._logger.info("NOTE: parameters will be re-initialized.")
dist_startup_prog = dist_context.dist_startup_programs[
self._cur_rank
]
self._executor.run(dist_startup_prog)
else:
raise NotImplementedError("_initialize() only support PIR now.")

# distributed training combined with prim mechanism (prim is behind of distributed)
# for local main subprogram after distributed partition,
@@ -2555,29 +2372,41 @@ def cost(
def get_dist_main_program(self, mode: _Mode) -> Program:
if self._in_pir_mode:
return self._pir_dist_main_progs[self._mode]
return self._dist_contexts[mode].dist_main_programs[self._cur_rank]
else:
raise NotImplementedError(
"get_dist_main_program() only support PIR now."
)

def get_dist_startup_program(self, mode: _Mode) -> Program:
if self._in_pir_mode:
return self._pir_dist_startup_progs[self._mode]
return self._dist_contexts[mode].dist_startup_programs[self._cur_rank]
else:
raise NotImplementedError(
"get_dist_startup_program() only support PIR now."
)

def get_serial_main_program(self, mode: _Mode) -> Program:
if self._in_pir_mode:
return self._fwd_main_progs[mode]
return self._dist_contexts[mode].serial_main_program
else:
raise NotImplementedError(
"get_serial_main_program() only support PIR now."
)

def get_serial_startup_program(self, mode: _Mode) -> Program:
if self._in_pir_mode:
return self._startup_progs[mode]
return self._dist_contexts[mode].serial_startup_program
else:
raise NotImplementedError(
"get_serial_startup_program() only support PIR now."
)

@property
def main_program(self) -> Program:
if self._in_pir_mode:
return self._pir_dense_main_progs[self._mode]
dist_context = self._dist_contexts[self._mode]
return dist_context.dist_main_programs[self._cur_rank]
else:
raise NotImplementedError("main_program() only support PIR now.")

@property
def startup_program(self) -> Program:
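The same cleanup pattern repeats throughout engine.py: each legacy (old IR) branch is deleted and the non-PIR path now raises NotImplementedError, so only the PIR executor path remains. Below is a minimal, self-contained sketch of that guard pattern; the _EngineSketch class and its constructor argument are illustrative placeholders, not the actual Paddle API.

class _EngineSketch:
    # Illustrative stand-in for the auto-parallel Engine; not the real class.
    def __init__(self, in_pir_mode: bool):
        # In the real engine this flag reflects whether the PIR
        # (Paddle Intermediate Representation) program path is active.
        self._in_pir_mode = in_pir_mode

    def _init_comm(self) -> None:
        if not self._in_pir_mode:
            # Legacy old-IR branch removed by this PR; only PIR is supported.
            raise NotImplementedError("_init_comm() only support PIR now.")
        # ... PIR-only logic: instantiate all process groups ...

engine = _EngineSketch(in_pir_mode=False)
try:
    engine._init_comm()
except NotImplementedError as err:
    print(err)  # _init_comm() only support PIR now.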
23 changes: 0 additions & 23 deletions test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -12,21 +12,6 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_semi_auto_parallel_hybrid_strategy
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_semi_auto_parallel_llama_model MODULES
test_semi_auto_parallel_llama_model ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_semi_auto_parallel_llama_model
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_save_load_state_dict MODULES test_save_load_state_dict ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_save_load_state_dict
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_semi_auto_parallel_c_cross_entropy MODULES
@@ -81,14 +66,6 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_semi_auto_parallel_multi_inputs
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_semi_auto_parallel_llama_model_vpp MODULES
test_semi_auto_parallel_llama_model_vpp ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_semi_auto_parallel_llama_model_vpp
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_semi_auto_parallel_llama_model_pir