|
13 | 13 | # pylint: disable=missing-docstring
|
14 | 14 |
|
15 | 15 | def rewrite_biasadd_with_conv2d(g, ops):
|
16 |
| - pattern = \ |
| 16 | + pattern1 = \ |
17 | 17 | OpTypePattern('BiasAdd', name='biasadd', inputs=[
|
18 | 18 | OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
|
19 |
| - matcher = GraphMatcher(pattern) |
20 |
| - match_results = list(matcher.match_ops(ops)) |
21 |
| - for match in match_results: |
22 |
| - biasadd = match.get_op('biasadd') |
23 |
| - conv = match.get_op('conv') |
24 |
| - |
25 |
| - #backup the conv and biasadd values |
26 |
| - conv_type = conv.type |
27 |
| - conv_input = conv.input |
28 |
| - conv_attr = conv.attr |
29 |
| - dtype = g.get_dtype(conv.output[0]) |
30 |
| - shape = g.get_shape(conv.output[0]) |
31 |
| - conv_name = biasadd.name |
32 |
| - conv_output = biasadd.output |
33 |
| - conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]] |
34 |
| - |
35 |
| - if len(g.find_output_consumers(conv.output[0])) > 1: |
36 |
| - continue |
37 |
| - # Remove the Conv and BiasAdd node |
38 |
| - g.remove_node(conv.name) |
39 |
| - g.remove_node(biasadd.name) |
40 |
| - |
41 |
| - g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output, |
42 |
| - shapes=[shape], dtypes=[dtype], skip_conversion=False) |
| 19 | + pattern2 = \ |
| 20 | + OpTypePattern('BiasAdd', name='biasadd', inputs=[ |
| 21 | + OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*', '*']), '*'], allow_reorder = True) |
| 22 | + |
| 23 | + for pattern in [pattern1, pattern2]: |
| 24 | + matcher = GraphMatcher(pattern) |
| 25 | + match_results = list(matcher.match_ops(ops)) |
| 26 | + for match in match_results: |
| 27 | + biasadd = match.get_op('biasadd') |
| 28 | + conv = match.get_op('conv') |
| 29 | + |
| 30 | + #backup the conv and biasadd values |
| 31 | + conv_type = conv.type |
| 32 | + conv_input = conv.input |
| 33 | + conv_attr = conv.attr |
| 34 | + dtype = g.get_dtype(conv.output[0]) |
| 35 | + shape = g.get_shape(conv.output[0]) |
| 36 | + conv_name = biasadd.name |
| 37 | + conv_output = biasadd.output |
| 38 | + if pattern == pattern2: |
| 39 | + conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]] |
| 40 | + else: |
| 41 | + conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]] |
| 42 | + |
| 43 | + if len(g.find_output_consumers(conv.output[0])) > 1: |
| 44 | + continue |
| 45 | + # Remove the Conv and BiasAdd node |
| 46 | + g.remove_node(conv.name) |
| 47 | + g.remove_node(biasadd.name) |
| 48 | + |
| 49 | + g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output, |
| 50 | + shapes=[shape], dtypes=[dtype], skip_conversion=False) |
43 | 51 | return ops
|
0 commit comments