164 changes: 101 additions & 63 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -289,13 +289,15 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
Node& div_node = *graph.GetNode(p_div->Index());

// If Div has 2+ outputs, build layer_norm_input_defs with scale = 1 and bias = nullptr.
bool skip_scale_bias_for_layer_norm = !optimizer_utils::CheckOutputEdges(graph, div_node, 1);

if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13, 14}) ||
div_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, div_node, 1) ||
!IsSupportedDataType(div_node)) {
continue;
}
nodes_to_remove.push_back(div_node);

// Trace back from the div node to find sqrt --> div
const Node* p_sqrt = graph_utils::FirstParentByType(div_node, "Sqrt");
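Note for reviewers: the skip_scale_bias_for_layer_norm path introduced above matches sub-graphs that end at Div, i.e. normalization with no trailing Mul (scale) or Add (bias). A minimal NumPy sketch of what that truncated pattern computes (the function name and epsilon handling are illustrative, not from this change):

```python
import numpy as np

def layer_norm_without_scale_bias(x, axis=-1, epsilon=0.0):
    # What the matched sub-graph computes when Div is the last node:
    # y = (x - mean(x)) / sqrt(mean((x - mean(x))**2) + epsilon)
    mean = x.mean(axis=axis, keepdims=True)
    variance = ((x - mean) ** 2).mean(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(variance + epsilon)
```

This is equivalent to LayerNormalization with a scale of all ones and no bias, which is exactly what the fusion synthesizes below.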
@@ -366,39 +368,15 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
}

// Apex O2 pattern specific match starts...
// Since we support input and scale/bias in different data types, the Cast ops in this
// sub-graph can be removed. One place a Cast op can appear is between the Div and Mul nodes.
// div --> mul or div --> cast --> mul
Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19}) &&
optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
nodes_to_remove.push_back(*next_node);
next_node = graph.GetNode(next_node->OutputNodesBegin()->Index());
}
// Apex O2 pattern specific match ends...

Node& mul_node = *next_node;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
mul_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, mul_node, 1) ||
!IsSupportedDataType(mul_node)) {
continue;
}
nodes_to_remove.push_back(mul_node);
// Make the Div node the last entry in nodes_to_remove for the skip_scale_bias_for_layer_norm case.
nodes_to_remove.push_back(div_node);

// mul --> add
// No need to check the output edges of the last node since they will be moved to the fused node.
Node& last_add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13, 14}) ||
last_add_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!IsSupportedDataType(last_add_node)) {
continue;
}
nodes_to_remove.push_back(last_add_node);
// Skip mul and last_add_node when using a null scale/bias.
NodeArg* scale = nullptr;
NodeArg* bias = nullptr;
Node* last_node_ptr = nullptr;

// Get the axes attributes.
auto axes_values = GetAxesFromReduceMeanNode(reduce_mean_node, graph);
auto axes2_values = GetAxesFromReduceMeanNode(reduce_mean2_node, graph);
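For context, the Apex O2 variant handled in this matcher simply has a Cast wedged between Div and Mul. A hedged sketch of that tail using the onnx helper API (all tensor/node names are made up for the sketch):

```python
from onnx import TensorProto, helper

# Illustrative tail of the Apex O2 variant: div --> cast --> mul --> add.
apex_o2_tail = [
    helper.make_node("Div", ["sub_out", "sqrt_out"], ["div_out"], "div"),
    helper.make_node("Cast", ["div_out"], ["cast_out"], "cast", to=TensorProto.FLOAT),
    helper.make_node("Mul", ["cast_out", "scale"], ["mul_out"], "mul"),
    helper.make_node("Add", ["mul_out", "bias"], ["Y"], "add"),
]
```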

@@ -423,42 +401,98 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
#endif

// Get the inputs for the new LayerNormalization node.
// scale and bias can be multi-dimensional; we only support that for training at the moment
// because the SkipLayerNorm kernel, for example, depends on a single dim size.
NodeArg* scale = nullptr;
NodeArg* bias = nullptr;
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
if (!skip_scale_bias_for_layer_norm) {
// Apex O2 pattern specific match starts...
// Since we support input and scale/bias in different data types, the Cast ops in this
// sub-graph can be removed. One place a Cast op can appear is between the Div and Mul nodes.
// div --> mul or div --> cast --> mul
Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
if (graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13, 19}) &&
optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
nodes_to_remove.push_back(*next_node);
next_node = graph.GetNode(next_node->OutputNodesBegin()->Index());
}
// Apex O2 pattern specific match ends...

Node& mul_node = *next_node;
last_node_ptr = &mul_node;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
mul_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, mul_node, 1) ||
!IsSupportedDataType(mul_node)) {
continue;
}
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
scale = mul_node.MutableInputDefs()[i];
nodes_to_remove.push_back(mul_node);

// mul --> add
// No need to check the output edges of the last node since they will be moved to the fused node.
Node& last_add_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(last_add_node, "Add", {7, 13, 14}) ||
last_add_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!IsSupportedDataType(last_add_node)) {
continue;
}
nodes_to_remove.push_back(last_add_node);

// Get the inputs for the new LayerNormalization node.
// scale and bias can be multi-dimensional; we only support that for training at the moment
// because the SkipLayerNorm kernel, for example, depends on a single dim size.
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) {
continue;
}
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
scale = mul_node.MutableInputDefs()[i];
}
}
}

for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) {
if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) {
continue;
for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) {
if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) {
continue;
}
if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
bias = last_add_node.MutableInputDefs()[i];
}
}
if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
bias = last_add_node.MutableInputDefs()[i];
if (scale == nullptr || bias == nullptr) {
continue;
}
}
if (scale == nullptr || bias == nullptr) {
continue;
}

// Scale and bias must have the same shape.
bool same_dim = true;
for (int i = 0; i < scale->Shape()->dim_size(); i++) {
if (scale->Shape()->dim(i).dim_value() != bias->Shape()->dim(i).dim_value()) {
same_dim = false;
break;
// Scale and bias must have the same shape.
bool same_dim = true;
for (int i = 0; i < scale->Shape()->dim_size(); i++) {
if (scale->Shape()->dim(i).dim_value() != bias->Shape()->dim(i).dim_value()) {
same_dim = false;
break;
}
}
}
if (!same_dim)
continue;
if (!same_dim)
continue;
} else {
last_node_ptr = &div_node;

// Get the output dim value on the axis reduced by ReduceMean.
const NodeArg* div_output_arg = div_node.OutputDefs()[0];
int dim_size = div_output_arg->Shape()->dim_size();
int reduce_mean_axes = static_cast<int>(axes_values[0]);
reduce_mean_axes = (reduce_mean_axes + dim_size) % dim_size;

int64_t reduced_axis_length = div_output_arg->Shape()->dim(reduce_mean_axes).dim_value();
int div_output_dtype = div_output_arg->TypeAsProto()->tensor_type().elem_type();

// LayerNorm requires 2 inputs, x and scale. Fill a 1-D scale vector with 1.0f when scale is unused.
ONNX_NAMESPACE::TensorProto scale_constant;
std::string const_name = graph.GenerateNodeName(last_node_ptr->Name() + "/LayerNormFusion/Scale");
scale_constant.set_name(const_name);
scale_constant.set_data_type(div_output_dtype);
// scale must have the same length as the output tensor's dim value on the reduced axis.
scale_constant.add_dims(reduced_axis_length);
for (int64_t i = 0; i < reduced_axis_length; ++i) {
scale_constant.add_float_data(1.0f);
}
graph.AddInitializedTensor(scale_constant);
scale = &graph.GetOrCreateNodeArg(const_name, nullptr);
} // end if (!skip_scale_bias_for_layer_norm)

NodeArg* x_input = has_leading_cast ? graph.GetNode(p_reduce_mean_input_node->Index())->MutableInputDefs()[0]
: reduce_mean_node.MutableInputDefs()[0];
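The synthesized scale in the else-branch above is a 1-D initializer of ones whose length is the dim value on the reduced axis. An equivalent sketch in Python for a [16, 32, 4] input reduced over axis -1 (the tensor name is illustrative; the C++ uses GenerateNodeName):

```python
from onnx import TensorProto, helper

# Mirrors the C++ axis normalization: (-1 + 3) % 3 == 2, and the dim value
# at axis 2 is 4, so the synthesized scale is four 1.0f values.
reduced_axis_length = 4
scale_ones = helper.make_tensor(
    "layer_norm_scale_ones",  # illustrative name
    TensorProto.FLOAT,
    dims=[reduced_axis_length],
    vals=[1.0] * reduced_axis_length,
)
```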
Expand All @@ -469,8 +503,12 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}

InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale, bias};
Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/LayerNormFusion/"),
InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale};
if (!skip_scale_bias_for_layer_norm) {
layer_norm_input_defs.push_back(bias);
}

Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName(last_node_ptr->Name() + "/LayerNormFusion/"),
"LayerNormalization",
"fused LayerNorm subgraphs ",
layer_norm_input_defs,
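This two-input form is valid because LayerNormalization's bias input is optional. A quick sketch of both node shapes in the onnx helper API (attribute values are illustrative):

```python
from onnx import helper

# Bias (the third input) is optional for LayerNormalization, so the fusion
# can emit either form below.
ln_no_bias = helper.make_node(
    "LayerNormalization", ["X", "scale"], ["Y"], axis=-1, epsilon=1e-5
)
ln_with_bias = helper.make_node(
    "LayerNormalization", ["X", "scale", "bias"], ["Y"], axis=-1, epsilon=1e-5
)
```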
@@ -491,7 +529,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,

// Set stash_type to double if any input is double; otherwise keep the default (float).
if (x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE ||
scale->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
(!skip_scale_bias_for_layer_norm && scale->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)) {
layer_norm_node.AddAttribute("stash_type", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE));
}
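stash_type selects the floating-point precision used for the intermediate mean/variance computation. A hedged way to inspect what the fusion set, using the onnx Python API (the model path is hypothetical):

```python
import onnx
from onnx import TensorProto

model = onnx.load("fused_model.onnx")  # hypothetical path to an optimized model
for node in model.graph.node:
    if node.op_type == "LayerNormalization":
        stash = {a.name: a.i for a in node.attribute}.get("stash_type", TensorProto.FLOAT)
        # Default is FLOAT (1); this fusion sets DOUBLE (11) only when
        # x or scale is double.
        print(node.name, stash)
```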

37 changes: 37 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
@@ -77,6 +77,43 @@ TEST_F(GraphTransformationTests, LayerNormFusionTest) {
}
}

TEST_F(GraphTransformationTests, LayerNormFusionTestWithoutScaleBias) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_without_scale_bias.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<LayerNormFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Sub"] == 0);
ASSERT_TRUE(op_to_count["ReduceMean"] == 0);
ASSERT_TRUE(op_to_count["Pow"] == 0);
ASSERT_TRUE(op_to_count["Sqrt"] == 0);
ASSERT_TRUE(op_to_count["LayerNormalization"] == 1);

for (const Node& node : graph.Nodes()) {
if (node.OpType() == "LayerNormalization") {
// LayerNormalization should have two inputs (synthesized scale, no bias).
EXPECT_EQ(node.InputDefs().size(), 2u)
<< "LayerNormalization number of inputs does not equal 2. Got: " << node.InputDefs().size();
const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape();
EXPECT_EQ(scale_shape->dim_size(), 1)
<< "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size();
} else {
EXPECT_TRUE(false) << "Unexpected node " << node.Name();
}
}
}

TEST_F(GraphTransformationTests, TwoLayerNormShareSameInput) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_shared_input.onnx";
std::shared_ptr<Model> p_model;
Binary file not shown (the layer_norm_without_scale_bias.onnx test model, generated by the script below).
@@ -0,0 +1,46 @@
import onnx
from onnx import OperatorSetIdProto, TensorProto, helper


def GenerateModel(model_name): # noqa: N802
nodes = [ # LayerNorm subgraph without scale/bias
helper.make_node("ReduceMean", ["X"], ["rd1_out"], "reduce", axes=[-1]),
helper.make_node("Sub", ["X", "rd1_out"], ["sub1_out"], "sub"),
helper.make_node("Pow", ["sub1_out", "pow_in_2"], ["pow_out"], "pow"),
helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1]),
helper.make_node("Add", ["rd2_out", "const_0"], ["add1_out"], "add"),
helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"),
helper.make_node("Div", ["sub1_out", "sqrt_out"], ["Y"], "div"),
]

initializers = [ # initializers
helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]),
helper.make_tensor("const_0", TensorProto.FLOAT, [], [0]),
]

graph = helper.make_graph(
nodes,
"LayerNormWithoutScaleBias", # name
[ # inputs
helper.make_tensor_value_info("X", TensorProto.FLOAT, [16, 32, 4]),
],
[ # outputs
helper.make_tensor_value_info("Y", TensorProto.FLOAT, [16, 32, 4]),
],
initializers,
)

onnxdomain = OperatorSetIdProto()
onnxdomain.version = 12
# The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
onnxdomain.domain = ""
msdomain = OperatorSetIdProto()
msdomain.version = 1
msdomain.domain = "com.microsoft"
opsets = [onnxdomain, msdomain]

model = helper.make_model(graph, opset_imports=opsets)
onnx.save(model, model_name)


GenerateModel("layer_norm_without_scale_bias.onnx")
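A possible end-to-end check after generating the model: run it through onnxruntime, which applies this fusion at its default optimization level, and compare against a NumPy reference (tolerances are illustrative):

```python
import numpy as np
import onnxruntime as ort

x = np.random.rand(16, 32, 4).astype(np.float32)
sess = ort.InferenceSession(
    "layer_norm_without_scale_bias.onnx", providers=["CPUExecutionProvider"]
)
(y,) = sess.run(None, {"X": x})

# Reference: LayerNorm over the last axis with scale = 1, no bias, epsilon = 0.
mean = x.mean(axis=-1, keepdims=True)
ref = (x - mean) / np.sqrt(((x - mean) ** 2).mean(axis=-1, keepdims=True))
np.testing.assert_allclose(y, ref, rtol=1e-4, atol=1e-5)
```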