@@ -2817,6 +2817,31 @@ def func(x):
2817
2817
return tf .identity (y , name = _TFOUTPUT )
2818
2818
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 )
2819
2819
2820
+ @check_opset_min_version (7 , "batchnorm" )
2821
+ @check_tf_min_version ("2.0" , "tf-1.x does not support NDHWC" )
2822
+ def test_fused_batchnorm_3d (self ):
2823
+ x_shape = [1 , 28 , 28 , 2 , 2 ]
2824
+ x_dtype = np .float32
2825
+ scale_dtype = np .float32
2826
+ scale_shape = [2 ]
2827
+ data_format = "NDHWC"
2828
+ x_val = np .random .random_sample (x_shape ).astype (x_dtype )
2829
+ scale_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
2830
+ offset_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
2831
+ mean_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
2832
+ var_val = np .random .random_sample (scale_shape ).astype (scale_dtype )
2833
+ def func (x ):
2834
+ scale = tf .constant (scale_val , name = 'scale' )
2835
+ offset = tf .constant (offset_val , name = 'offset' )
2836
+ mean = tf .constant (mean_val , name = 'mean' )
2837
+ var = tf .constant (var_val , name = 'variance' )
2838
+ epsilon = 0.001
2839
+ y , _ , _ = fused_batch_norm (
2840
+ x , scale , offset , mean = mean , variance = var ,
2841
+ epsilon = epsilon , data_format = data_format , is_training = False )
2842
+ return tf .identity (y , name = _TFOUTPUT )
2843
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-04 )
2844
+
2820
2845
@check_opset_min_version (7 , "batchnorm" )
2821
2846
@skip_tfjs ("TFJS executes model incorrectly" )
2822
2847
def test_fused_batchnorm_training (self ):
0 commit comments