
Commit 0d4ec08 (merge of parents c2ba45a and 903f43f)

feat: InstanceNormalization

4 files changed: +137 −1 lines changed

onnx2torch/node_converters/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@
 from onnx2torch.node_converters.gemm import *
 from onnx2torch.node_converters.global_average_pool import *
 from onnx2torch.node_converters.identity import *
+from onnx2torch.node_converters.instance_norm import *
 from onnx2torch.node_converters.logical import *
 from onnx2torch.node_converters.lrn import *
 from onnx2torch.node_converters.matmul import *
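This wildcard import is what activates the new converter: importing instance_norm runs its @add_converter decorators at import time, and registration is a side effect of that import. As a rough illustration of the pattern (a minimal sketch; the real registry lives in onnx2torch/node_converters/registry.py, and the names below are illustrative, not onnx2torch's actual internals):

# Minimal sketch of a decorator-based converter registry. Illustrative only;
# onnx2torch's actual implementation may differ.
from typing import Callable, Dict, Tuple

_CONVERTER_REGISTRY: Dict[Tuple[str, int], Callable] = {}


def add_converter(operation_type: str, version: int) -> Callable:
    def decorator(converter: Callable) -> Callable:
        # Registration happens when the module is imported, which is why
        # the package __init__ must wildcard-import every converter module.
        _CONVERTER_REGISTRY[(operation_type, version)] = converter
        return converter

    return decorator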
onnx2torch/node_converters/instance_norm.py (new file)

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
__all__ = [
    'OnnxInstanceNorm',
]

import torch
import torch.nn.functional as F
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxMapping
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import get_shape_from_value_info
from onnx2torch.utils.common import onnx_mapping_from_node

_IN_CLASS_FROM_SPATIAL_RANK = {
    0: nn.InstanceNorm1d,
    1: nn.InstanceNorm1d,
    2: nn.InstanceNorm2d,
    3: nn.InstanceNorm3d,
}


class OnnxInstanceNorm(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-docstring
    def __init__(self, momentum: float, epsilon: float):
        super().__init__()
        self.momentum = momentum
        self.epsilon = epsilon

    def forward(  # pylint: disable=missing-function-docstring
        self,
        input_data: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
    ) -> torch.Tensor:
        return F.instance_norm(
            input=input_data,
            running_mean=None,
            running_var=None,
            weight=weight,
            bias=bias,
            use_input_stats=True,
            momentum=self.momentum,
            eps=self.epsilon,
        )


@add_converter(operation_type='InstanceNormalization', version=1)
@add_converter(operation_type='InstanceNormalization', version=6)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
    node_attributes = node.attributes
    epsilon = node_attributes.get('epsilon', 1e-5)
    momentum = 0.1

    if all(value_name in graph.initializers for value_name in node.input_values[1:]):
        input_value_info = graph.value_info[node.input_values[0]]
        input_shape = get_shape_from_value_info(input_value_info)
        spatial_rank = len(input_shape) - 2
        try:
            in_class = _IN_CLASS_FROM_SPATIAL_RANK[spatial_rank]
        except KeyError as exc:
            raise NotImplementedError(
                f'InstanceNorm operation with spatial rank == {spatial_rank} is not implemented'
            ) from exc

        scale_value_name = node.input_values[1]
        bias_value_name = node.input_values[2]

        scale = graph.initializers[scale_value_name].to_torch()
        torch_module = in_class(
            num_features=scale.size()[0],
            eps=epsilon,
            momentum=momentum,
            affine=True,
            track_running_stats=False,
        )
        with torch.no_grad():
            torch_module.weight.data = graph.initializers[scale_value_name].to_torch()
            torch_module.bias.data = graph.initializers[bias_value_name].to_torch()

        onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values)
    else:
        torch_module = OnnxInstanceNorm(momentum=momentum, epsilon=epsilon)
        onnx_mapping = onnx_mapping_from_node(node)

    return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping)
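The converter takes one of two paths. When scale and bias are graph initializers, it builds a native nn.InstanceNormXd with the weights baked in (and the ONNX mapping keeps only the data input). Otherwise it falls back to OnnxInstanceNorm, which hands the runtime tensors to F.instance_norm with use_input_stats=True, so statistics always come from the current input. Per sample and channel, that computes (x - mean) / sqrt(var + eps) * weight + bias over the spatial dimensions, using the biased variance. A quick self-contained check of that equivalence (shapes and tolerance chosen arbitrarily for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)  # (N, C, H, W)
weight = torch.rand(3)       # per-channel scale
bias = torch.randn(3)        # per-channel shift
eps = 1e-5

# Manual instance norm: statistics per (sample, channel) over spatial dims,
# with the biased variance estimator.
mean = x.mean(dim=(2, 3), keepdim=True)
var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps)
manual = manual * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

reference = F.instance_norm(x, weight=weight, bias=bias, eps=eps)
assert torch.allclose(manual, reference, atol=1e-5)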

operators.md

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
 | Hardmax | N | |
 | Identity | Y | |
 | If | N | |
-| InstanceNormalization | N | |
+| InstanceNormalization | Y | |
 | IsInf | N | |
 | IsNaN | N | |
 | LRN | Y | |
New test file for InstanceNormalization

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
from typing import List

import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


@pytest.mark.parametrize('parameters_as_inputs', (True, False))
@pytest.mark.parametrize(
    'input_shape',
    (
        # 1d
        [2, 3, 16],
        [2, 1, 7],
        # 2d
        [2, 3, 16, 16],
        [2, 1, 7, 16],
        # 3d
        [2, 3, 16, 16, 16],
        [2, 1, 16, 7, 16],
    ),
)
def test_instance_norm(  # pylint: disable=missing-function-docstring
    input_shape: List[int],
    parameters_as_inputs: bool,
) -> None:
    num_features = input_shape[1]
    x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32)
    scale = np.random.uniform(low=0.0, high=1.0, size=num_features).astype(np.float32)
    bias = np.random.uniform(low=-1.0, high=1.0, size=num_features).astype(np.float32)

    inputs = {'input': x}
    parameters = {'scale': scale, 'bias': bias}
    initializers = {}

    if parameters_as_inputs:
        inputs.update(parameters)
    else:
        initializers.update(parameters)

    node = onnx.helper.make_node(op_type='InstanceNormalization', inputs=['input', 'scale', 'bias'], outputs=['y'])

    model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs)
    check_onnx_model(onnx_model=model, onnx_inputs=inputs, atol_onnx_torch=1e-6, atol_torch_cpu_cuda=1e-6)
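The parametrization exercises both converter paths: with parameters_as_inputs=True the scale and bias arrive as runtime inputs (the OnnxInstanceNorm fallback), and with False they become initializers (the native nn.InstanceNormXd path). Once merged, models containing InstanceNormalization convert through the library's usual entry point. A minimal usage sketch, assuming a placeholder model path and input shape:

import onnx
import torch
from onnx2torch import convert

# 'model.onnx' is a placeholder for any ONNX model that contains
# InstanceNormalization nodes.
torch_model = convert(onnx.load('model.onnx'))

x = torch.randn(2, 3, 16, 16)  # example input; must match the model's signature
with torch.no_grad():
    y = torch_model(x)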
