Skip to content

Fix a bug that multiple (conv, batch_norm) ops could not be optimized. #2187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ The common issues we run into we try to document here [Troubleshooting Guide](Tr

| Build Type | OS | Python | TensorFlow | ONNX opset | Status |
| --- | --- | --- | --- | --- | --- |
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.7-3.10 | 1.13-1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=main) |
| Unit Test - Full | Linux, MacOS, Windows | 3.7-3.10 | 1.13-1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=main) | |
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.7-3.10 | 1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=main) |
| Unit Test - Full | Linux, MacOS, Windows | 3.7-3.10 | 1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=main) | |
<br/>

## Supported Versions
Expand Down
12 changes: 0 additions & 12 deletions ci_build/azure_pipelines/onnxruntime_nightly_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,6 @@ stages:
- template: 'unit_test.yml'
report_coverage: 'True'

- template: 'templates/job_generator.yml'
parameters:
platforms: ['linux', 'windows']
python_versions: ['3.7']
tf_versions: ['1.14.0']
onnx_opsets: ['']
onnx_backends: {onnxruntime: ['nightly']}
job:
steps:
- template: 'unit_test.yml'
report_coverage: 'True'

- template: 'templates/job_generator.yml'
parameters:
platforms: ['linux', 'windows']
Expand Down
9 changes: 0 additions & 9 deletions ci_build/azure_pipelines/pretrained_model_test-matrix.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
# Pre-trained model test, full matrix

jobs:
- template: 'templates/job_generator.yml'
parameters:
platforms: ['linux', 'windows']
python_versions: ['3.7']
tf_versions: ['1.14.0']
job:
steps:
- template: 'pretrained_model_test.yml'

- template: 'templates/job_generator.yml'
parameters:
platforms: ['linux', 'windows']
Expand Down
2 changes: 1 addition & 1 deletion ci_build/azure_pipelines/unit_test-matrix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ stages:
parameters:
platforms: ['linux', 'windows']
python_versions: ['3.7']
tf_versions: ['1.14.0', '1.15.2']
tf_versions: ['1.15.2']
onnx_opsets: ['']
job:
steps:
Expand Down
10 changes: 0 additions & 10 deletions ci_build/azure_pipelines/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,6 @@ stages:
- template: 'unit_test.yml'
report_coverage: 'True'

- template: 'templates/job_generator.yml'
parameters:
platforms: ['windows']
tf_versions: ['1.14.0']
onnx_opsets: ['14']
job:
steps:
- template: 'unit_test.yml'
report_coverage: 'True'

- template: 'templates/job_generator.yml'
parameters:
python_versions: ['3.8']
Expand Down
51 changes: 51 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3087,6 +3087,57 @@ def graph_validator(g):

self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)

@check_opset_min_version(7, "batchnorm")
def test_multiple_conv2d_fused_batchnorm(self):
x_shape = [1, 28, 28, 2]
x_val = np.random.random_sample(x_shape).astype(np.float32)
w = np.array([[2., 1., 1.],
[1., 3., 1.],
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
# 2 channels for input and output
w = np.concatenate([w, w, w, w]).reshape([3, 3, 2, 2])
scale_dtype = np.float32
scale_shape = x_shape[-1:]
scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
var_val = np.random.random_sample(scale_shape).astype(scale_dtype)

def func_conv2d(x):
kernel = tf.constant(w, dtype=tf.float32, name='k')
conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
return conv

def func_multiple_fusedbn(x):
scale = tf.constant(scale_val, name='scale')
offset = tf.constant(offset_val, name='offset')
mean = tf.constant(mean_val, name='mean')
var = tf.constant(var_val, name='variance')
epsilon = 0.1234
y, _, _ = fused_batch_norm(
func_conv2d(x), scale, offset, mean=mean, variance=var,
epsilon=epsilon, data_format='NHWC', is_training=False)

y = tf.nn.relu(y)

y, _, _ = fused_batch_norm(
func_conv2d(y), scale, offset, mean=mean, variance=var,
epsilon=epsilon, data_format='NHWC', is_training=False)

y, _, _ = fused_batch_norm(
func_conv2d(y), scale, offset, mean=mean, variance=var,
epsilon=epsilon, data_format='NHWC', is_training=False)

return tf.identity(y, name=_TFOUTPUT)

def graph_validator(g):
if 'BatchNormalization' in [n.type for n in g.get_nodes()]:
return False
return True

self._run_test_case(func_multiple_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05,
graph_validator=graph_validator)

@check_tf_min_version("1.15")
@check_opset_min_version(10, "quantize_and_dequantize")
def test_qdq_unsigned_input(self):
Expand Down
5 changes: 3 additions & 2 deletions tf2onnx/optimizer/back_to_back_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def _optimize_at_current_graph_level(self, g):

# topological sort of candidates
# simplifying assumption for back-to-back-optimizer is
# the op_types have 1 input, 1 output, but multiple consumers
# the op_types have 1 input, 1 output, but multiple consumers.
# if optype contains 2 elements, the second element should not be considered as a consumer.
has_dependencies = set()
consumer_node_ids = {n.output[0]: [] for n in nodes}
consumer_node_ids = {n.output[0]: [] for n in nodes if len(optype) < 2 or n.type == optype[0]}
for n in nodes:
if n.input[0] in consumer_node_ids:
consumer_node_ids[n.input[0]].extend([n])
Expand Down