
Commit 5dfd36f

#1749: Fix fused_batch_norm 5d NDHWC input reshape convert (#1769)
* fix fused_batch_norm 5d input reshape convert
* fix for #1749

Signed-off-by: hwangdeyu <[email protected]>
1 parent d87ba34 commit 5dfd36f

File tree

2 files changed: +34 −1 lines changed


tests/test_backend.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -2817,6 +2817,31 @@ def func(x):
             return tf.identity(y, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)
 
+    @check_opset_min_version(7, "batchnorm")
+    @check_tf_min_version("2.0", "tf-1.x does not support NDHWC")
+    def test_fused_batchnorm_3d(self):
+        x_shape = [1, 28, 28, 2, 2]
+        x_dtype = np.float32
+        scale_dtype = np.float32
+        scale_shape = [2]
+        data_format = "NDHWC"
+        x_val = np.random.random_sample(x_shape).astype(x_dtype)
+        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        def func(x):
+            scale = tf.constant(scale_val, name='scale')
+            offset = tf.constant(offset_val, name='offset')
+            mean = tf.constant(mean_val, name='mean')
+            var = tf.constant(var_val, name='variance')
+            epsilon = 0.001
+            y, _, _ = fused_batch_norm(
+                x, scale, offset, mean=mean, variance=var,
+                epsilon=epsilon, data_format=data_format, is_training=False)
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)
+
     @check_opset_min_version(7, "batchnorm")
     @skip_tfjs("TFJS executes model incorrectly")
     def test_fused_batchnorm_training(self):
```
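
For reference (this sketch is not part of the commit), the new test corresponds to roughly the following end-to-end conversion through tf2onnx's public `from_function` API. The function name `bn_3d`, the output path, and the use of `tf.compat.v1.nn.fused_batch_norm` in place of the test module's `fused_batch_norm` helper are assumptions; opset 13 is an arbitrary choice above the opset-7 minimum the test requires.

```python
# Minimal sketch, assuming tf2onnx >= the version containing this fix and TF 2.x.
import numpy as np
import tensorflow as tf
import tf2onnx

channels = 2
scale = tf.constant(np.random.random_sample([channels]).astype(np.float32))
offset = tf.constant(np.random.random_sample([channels]).astype(np.float32))
mean = tf.constant(np.random.random_sample([channels]).astype(np.float32))
variance = tf.constant(np.random.random_sample([channels]).astype(np.float32))

@tf.function
def bn_3d(x):
    # 5D NDHWC input; inference mode, so mean/variance are supplied explicitly.
    y, _, _ = tf.compat.v1.nn.fused_batch_norm(
        x, scale, offset, mean=mean, variance=variance,
        epsilon=0.001, data_format="NDHWC", is_training=False)
    return tf.identity(y, name="output")

spec = (tf.TensorSpec([1, 28, 28, 2, channels], tf.float32, name="x"),)
model_proto, _ = tf2onnx.convert.from_function(
    bn_3d, input_signature=spec, opset=13,
    output_path="fused_batchnorm_3d.onnx")  # hypothetical output path
print([o.name for o in model_proto.graph.output])
```

Before this change the converter always took the two-spatial-dimension layout path for BatchNorm, so a 5D NDHWC input like the one above was reshaped/transposed as if it were 4D; the nn.py change below derives the number of spatial dimensions from the input rank instead.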

tf2onnx/onnx_opset/nn.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -943,6 +943,14 @@ class BatchNorm:
     @classmethod
     def version_6(cls, ctx, node, **kwargs):
         tf_type = node.type
+        input_rank = len(ctx.get_shape(node.input[0]))
+        if input_rank == 4:
+            spatial = 2
+        elif input_rank == 5:
+            spatial = 3
+        else:
+            raise ValueError("node input must be 4 or 5-dimensional, is {} now".format(input_rank))
+
         node.type = "BatchNormalization"
         # tf inputs: x, scale, bias, mean, variance
         # tf outputs: y, batch_mean, batch_var
@@ -973,7 +981,7 @@ def version_6(cls, ctx, node, **kwargs):
         # the setter makes a copy of new_output
         node.output = new_output
 
-        conv_convert_inputs(ctx, node, with_kernel=False)
+        conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial)
 
         inp_shape = ctx.get_shape(node.input[0])
         inp_rank = len(inp_shape) if inp_shape is not None else None
```
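
For context, the `spatial` argument tells the layout conversion how many spatial dimensions sit between the batch and channel axes, so a 5D NDHWC input is moved to NCDHW rather than being treated as NHWC. The actual transpose happens inside `conv_convert_inputs`; the helper below is illustrative only, its name is hypothetical, and it simply shows the permutation involved.

```python
# Illustrative only: permutation from channels-last to channels-first layout
# for a given number of spatial dims, which is the layout ONNX
# BatchNormalization expects (NCHW for 4D inputs, NCDHW for 5D inputs).
def channels_last_to_first_perm(spatial):
    rank = spatial + 2  # batch + spatial dims + channel
    return [0, rank - 1] + list(range(1, rank - 1))

assert channels_last_to_first_perm(2) == [0, 3, 1, 2]     # NHWC  -> NCHW
assert channels_last_to_first_perm(3) == [0, 4, 1, 2, 3]  # NDHWC -> NCDHW
```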
