@@ -30,6 +30,25 @@ class REWRITER_RESULT(Enum):
30
30
31
31
# TensorFlow LSTMCell/BasicLSTMCell and Keras LSTM computation graph matching
32
32
33
+ def insert_activation (activation , name = "" , inputs = None ):
34
+ inputs = inputs if inputs else [] # to avoid empty list as default arg
35
+ if activation == "hard_sigmoid" :
36
+ return OpTypePattern ("Maximum" , inputs = [
37
+ OpTypePattern ("Minimum" , inputs = [
38
+ OpTypePattern ("Add|AddV2" , inputs = [
39
+ OpTypePattern ("Mul" , inputs = [
40
+ * inputs ,
41
+ OpTypePattern ("*" ) # mul(x, 0.2)
42
+ ]), OpTypePattern ("*" ) # add(x, 0.5)
43
+ ]), OpTypePattern ("*" ) # minimum(x, 1)
44
+ ]), OpTypePattern ("*" ) # maximum(x, 0)
45
+ ])
46
+ # Additional activation pattern can be added when needed:
47
+ # https://www.tensorflow.org/api_docs/python/tf/keras/activations
48
+ # otherwise, use default activations
49
+ return OpTypePattern ("Tanh|Relu|Sigmoid" , name = name , inputs = inputs )
50
+
51
+
33
52
def make_lstm_xc_pattern (enter_or_id = "Enter" , from_keras = False , use_bias = False ):
34
53
if from_keras :
35
54
lstm_xh_pattern = OpTypePattern ("Add|AddV2" , allow_reorder = False , inputs = [
@@ -63,7 +82,8 @@ def make_lstm_xc_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
63
82
])
64
83
65
84
66
- def make_lstm_pattern (enter_or_id = "Enter" , from_keras = False , use_bias = False ):
85
+ def make_lstm_pattern (enter_or_id = "Enter" , from_keras = False , use_bias = False ,
86
+ activation = "" , recurrent_activation = "" ):
67
87
# split (Xt*(W[ifco]^T) + Ht-1*(R[ifco]^T)) on 'Const' axis
68
88
lstm_xc_pattern = OpTypePattern ('Split' , inputs = [
69
89
OpTypePattern ("Const" ),
@@ -77,23 +97,21 @@ def make_lstm_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
77
97
OpTypePattern ("*" , name = "ft_bias" ),
78
98
])
79
99
80
- activation = "Tanh|Relu|Sigmoid"
81
- recurrent_activation = "Tanh|Relu|Sigmoid"
82
-
83
- return OpTypePattern ("Mul" , name = 'ht' , inputs = [
84
- OpTypePattern (recurrent_activation , name = "ot" , inputs = [lstm_xc_pattern ]),
85
- OpTypePattern (activation , name = "ct'" , inputs = [
86
- OpTypePattern ("Add|AddV2" , name = "ct" , inputs = [
87
- OpTypePattern ("Mul" , name = "ct_identity_consumer" , inputs = [
88
- OpTypePattern (recurrent_activation , name = "ft" , inputs = [lstm_fb_pattern ]),
89
- OpTypePattern ("*" , name = "c" ),
90
- ]),
91
- OpTypePattern ("Mul" , inputs = [
92
- OpTypePattern (recurrent_activation , name = "it" , inputs = [lstm_xc_pattern ]),
93
- OpTypePattern (activation , name = "gt" , inputs = [lstm_xc_pattern ]),
94
- ]),
95
- ]),
100
+ # cell state
101
+ lstm_ct_pattern = OpTypePattern ("Add|AddV2" , name = "ct" , inputs = [
102
+ OpTypePattern ("Mul" , name = "ct_identity_consumer" , inputs = [
103
+ insert_activation (recurrent_activation , name = "ft" , inputs = [lstm_fb_pattern ]),
104
+ OpTypePattern ("*" , name = "c" ),
96
105
]),
106
+ OpTypePattern ("Mul" , inputs = [
107
+ insert_activation (recurrent_activation , name = "it" , inputs = [lstm_xc_pattern ]),
108
+ insert_activation (activation , name = "gt" , inputs = [lstm_xc_pattern ]),
109
+ ]),
110
+ ])
111
+
112
+ return OpTypePattern ("Mul" , name = "ht" , inputs = [
113
+ insert_activation (recurrent_activation , name = "ot" , inputs = [lstm_xc_pattern ]),
114
+ insert_activation (activation , name = "ct'" , inputs = [lstm_ct_pattern ]),
97
115
])
98
116
99
117
lstmcell_pattern = make_lstm_pattern ()
0 commit comments