
Commit 0df1bf7
feat: LayerNormalization
Co-authored-by: Mason Ma <[email protected]>
1 parent ef6f06a · commit 0df1bf7

File tree

4 files changed: +156 −0 lines changed
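
For context before the diffs: a minimal end-to-end sketch of how the new converter gets exercised. The model construction is illustrative; `onnx2torch.convert` (the library's top-level entry point) accepts an `onnx.ModelProto` or a path to one.

```python
import numpy as np
import torch
from onnx import TensorProto, helper, numpy_helper
from onnx2torch import convert

# One-node ONNX graph: y = LayerNormalization(x, scale, bias), opset 17.
x_info = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3, 16])
y_info = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 16])
scale = numpy_helper.from_array(np.ones(16, dtype=np.float32), name='scale')
bias = numpy_helper.from_array(np.zeros(16, dtype=np.float32), name='bias')
node = helper.make_node('LayerNormalization', ['x', 'scale', 'bias'], ['y'], axis=-1)
graph = helper.make_graph([node], 'ln', [x_info], [y_info], initializer=[scale, bias])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 17)])

# scale/bias are initializers, so the node maps to a plain torch.nn.LayerNorm.
torch_model = convert(model)
out = torch_model(torch.randn(2, 3, 16))
```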

onnx2torch/node_converters/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@
 from onnx2torch.node_converters.global_average_pool import *
 from onnx2torch.node_converters.identity import *
 from onnx2torch.node_converters.instance_norm import *
+from onnx2torch.node_converters.layer_norm import *
 from onnx2torch.node_converters.logical import *
 from onnx2torch.node_converters.lrn import *
 from onnx2torch.node_converters.matmul import *
```
onnx2torch/node_converters/layer_norm.py

Lines changed: 78 additions & 0 deletions

```python
__all__ = [
    'OnnxLayerNorm',
]

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxMapping
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import get_shape_from_value_info
from onnx2torch.utils.common import onnx_mapping_from_node

AXIS_DEFAULT_VALUE = -1
EPSILON_DEFAULT_VALUE = 1e-5


class OnnxLayerNorm(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-docstring
    def __init__(self, axis: int, epsilon: float):
        super().__init__()
        self.axis = axis
        self.epsilon = epsilon

    def forward(  # pylint: disable=missing-function-docstring
        self,
        inputs: torch.Tensor,
        scale: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Normalize over all dimensions from `axis` to the last one.
        normalized_shape = inputs.shape[self.axis :]
        return F.layer_norm(
            input=inputs,
            normalized_shape=normalized_shape,
            weight=scale,
            bias=bias,
            eps=self.epsilon,
        )


@add_converter(operation_type='LayerNormalization', version=17)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
    node_attributes = node.attributes

    axis = node_attributes.get('axis', AXIS_DEFAULT_VALUE)
    epsilon = node_attributes.get('epsilon', EPSILON_DEFAULT_VALUE)

    if all(value_name in graph.initializers for value_name in node.input_values[1:]):
        # Scale (and bias) are static initializers: bake them into nn.LayerNorm.
        input_value_info = graph.value_info[node.input_values[0]]
        input_shape = get_shape_from_value_info(input_value_info)

        torch_module = nn.LayerNorm(
            normalized_shape=input_shape[axis:],
            eps=epsilon,
            elementwise_affine=True,
        )

        scale_value_name = node.input_values[1]
        bias_value_name = node.input_values[2] if len(node.input_values) > 2 else None

        with torch.no_grad():
            torch_module.weight.data = graph.initializers[scale_value_name].to_torch()
            if bias_value_name is not None:
                torch_module.bias.data = graph.initializers[bias_value_name].to_torch()

        # Only the data tensor remains a module input; scale/bias are parameters now.
        onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values)
    else:
        # Scale/bias arrive as runtime inputs: fall back to the dynamic OnnxLayerNorm.
        torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon)
        onnx_mapping = onnx_mapping_from_node(node)

    return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping)
```
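
A quick usage sketch of the dynamic module defined above (shapes are illustrative): with `axis=1` on a rank-3 input, `normalized_shape` becomes `inputs.shape[1:]`, so `scale` and `bias` must cover the trailing two dimensions.

```python
import torch

from onnx2torch.node_converters.layer_norm import OnnxLayerNorm

ln = OnnxLayerNorm(axis=1, epsilon=1e-5)
x = torch.randn(2, 3, 16)
scale = torch.ones(3, 16)  # must match inputs.shape[axis:]
bias = torch.zeros(3, 16)
y = ln(x, scale, bias)  # normalizes over dims 1..2; output shape equals input shape
assert y.shape == x.shape
```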

operators.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -63,6 +63,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
 | InstanceNormalization | Y | |
 | IsInf | N | |
 | IsNaN | N | |
+| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" are not implemented |
 | LRN | Y | |
 | LSTM | N | |
 | LeakyRelu | Y | |
```
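
On the table note: opset-17 `LayerNormalization` can optionally emit `Mean` and `InvStdDev` alongside `Y`; the converter only produces `Y`. A NumPy sketch of the spec's semantics for all three outputs (my own reference, not library code):

```python
import numpy as np

def layer_norm_reference(x, scale, bias, axis=-1, epsilon=1e-5):
    # Normalization runs over dims [axis, rank); Mean/InvStdDev keep those dims as size 1.
    dims = tuple(range(axis % x.ndim, x.ndim))
    mean = x.mean(axis=dims, keepdims=True)
    inv_std_dev = 1.0 / np.sqrt(x.var(axis=dims, keepdims=True) + epsilon)
    y = (x - mean) * inv_std_dev * scale + bias
    return y, mean, inv_std_dev
```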
tests/node_converters/test_layer_norm.py

Lines changed: 76 additions & 0 deletions

```python
# pylint: disable=missing-function-docstring
from typing import List
from typing import Optional

import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_layer_norm(
    x: np.ndarray,
    scale: np.ndarray,
    bias: Optional[np.ndarray],
    axis: int,
    parameters_as_inputs: bool,
) -> None:
    inputs = {'input': x}
    parameters = {'scale': scale}
    if bias is not None:
        parameters['bias'] = bias

    initializers = {}

    # Exercise both converter branches: scale/bias as graph inputs or as initializers.
    if parameters_as_inputs:
        inputs.update(parameters)
    else:
        initializers.update(parameters)

    node = onnx.helper.make_node(
        op_type='LayerNormalization',
        inputs=['input', 'scale', 'bias'] if bias is not None else ['input', 'scale'],
        outputs=['y'],
        axis=axis,
    )
    model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs, opset_version=17)
    check_onnx_model(
        onnx_model=model,
        onnx_inputs=inputs,
        atol_onnx_torch=1e-5,
        atol_torch_cpu_cuda=1e-5,
        atol_onnx_torch2onnx=1e-5,
    )


@pytest.mark.parametrize('parameters_as_inputs', (True, False))
@pytest.mark.parametrize(
    'input_shape',
    (
        [2, 3, 16],
        [3, 1, 224],
        [4, 3, 16, 16],
        [5, 1, 32, 32],
        [6, 3, 16, 16, 8],
        [7, 1, 7, 7, 16],
    ),
)
def test_layer_norm(input_shape: List[int], parameters_as_inputs: bool) -> None:
    x = np.random.randn(*input_shape).astype(np.float32)

    for axis in [*range(len(input_shape))] + [-1]:
        normalized_shape = input_shape[axis:]

        scale = np.random.randn(*normalized_shape).astype(np.float32)
        bias = np.random.randn(*normalized_shape).astype(np.float32)

        for bias_ in [bias, None]:
            _test_layer_norm(
                x=x,
                scale=scale,
                bias=bias_,
                axis=axis,
                parameters_as_inputs=parameters_as_inputs,
            )
```
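
For contrast with the `parameters_as_inputs=True` case the tests cover: when `scale`/`bias` are graph inputs rather than initializers, the converter's initializer check fails and the node lowers to the dynamic `OnnxLayerNorm`. A standalone sketch (illustrative shapes; assumes the converted `torch.fx` module takes the graph inputs as positional arguments):

```python
import torch
from onnx import TensorProto, helper
from onnx2torch import convert

# scale/bias declared as graph inputs, not initializers.
x_info = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3, 16])
s_info = helper.make_tensor_value_info('scale', TensorProto.FLOAT, [16])
b_info = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [16])
y_info = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 16])
node = helper.make_node('LayerNormalization', ['x', 'scale', 'bias'], ['y'], axis=-1)
graph = helper.make_graph([node], 'ln_dynamic', [x_info, s_info, b_info], [y_info])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 17)])

# The wrapped OnnxLayerNorm receives scale and bias at call time.
torch_model = convert(model)
out = torch_model(torch.randn(2, 3, 16), torch.ones(16), torch.zeros(16))
```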
