File tree Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Expand file tree Collapse file tree 1 file changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -33,16 +33,22 @@ def rewrite_random_normal(g, ops):
33
33
match_results = list (matcher .match_ops (ops ))
34
34
for match in match_results :
35
35
output = match .get_op ('output' )
36
- if output .type == 'Add' :
36
+ input2 = match .get_op ('input2' )
37
+ is_output = False
38
+ for output_name in g .outputs :
39
+ # input2 and output can not be output node.
40
+ if input2 .name in output_name or output .name in output_name :
41
+ is_output = True
42
+ break
43
+ if is_output :
44
+ continue
45
+ if output .type == 'Add' and input2 .type == 'Mul' :
37
46
# pattern 1
38
47
mean = output .inputs [1 ].get_tensor_value ()
48
+ scale = input2 .inputs [1 ].get_tensor_value ()
39
49
else :
40
50
# pattern 2
41
51
mean = 0.0
42
- input2 = match .get_op ('input2' )
43
- if input2 .type == 'Mul' :
44
- scale = input2 .inputs [1 ].get_tensor_value ()
45
- else :
46
52
scale = 1.0
47
53
dtype = g .get_dtype (output .output [0 ])
48
54
op_name = utils .make_name ("RandomNormal" )
You can’t perform that action at this time.
0 commit comments