Skip to content

Commit 71105c1

Browse files
authored
Add handling of HardSigmoid recurrent activation for Keras LSTM (#2001)
Add pattern matching and parsing where Keras LSTM uses `HardSigmoid` as the recurrent activation. Signed-off-by: Yu Cong <[email protected]>
1 parent e7f39ed commit 71105c1

File tree

3 files changed

+94
-31
lines changed

3 files changed

+94
-31
lines changed

tests/test_lstm.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,28 @@ def func(x):
751751
return tf.identity(y[0], name="output"), tf.identity(y[1], name="output1")
752752
self.run_test_case(func, {"input:0": x_val}, [], ["output:0", "output1:0"], rtol=1e-05, atol=1e-06)
753753

754+
@check_tf_min_version("2.0")
755+
@skip_tf_versions("2.1", "Bug in TF 2.1")
756+
def test_keras_lstm_recurrent_activation_is_hard_sigmoid(self):
757+
in_shape = [10, 3]
758+
x_val = np.random.uniform(size=[2, 10, 3]).astype(np.float32)
759+
760+
model_in = tf.keras.layers.Input(tuple(in_shape), batch_size=2)
761+
x = tf.keras.layers.LSTM(
762+
units=5,
763+
return_sequences=True,
764+
return_state=True,
765+
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
766+
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
767+
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
768+
recurrent_activation="hard_sigmoid"
769+
)(model_in)
770+
model = tf.keras.models.Model(inputs=model_in, outputs=x)
771+
772+
def func(x):
773+
y = model(x)
774+
return tf.identity(y[0], name="output"), tf.identity(y[1], name="output1")
775+
self.run_test_case(func, {"input:0": x_val}, [], ["output:0", "output1:0"], rtol=1e-05, atol=1e-06)
754776

755777
if __name__ == '__main__':
756778
unittest_main()

tf2onnx/rewriter/lstm_tf2_rewriter.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,52 @@
1616

1717
# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
1818

19+
def _make_lstm_pattern_from_params(params):
20+
return make_lstm_pattern(enter_or_id="Identity") if not params.get("from_keras", False) \
21+
else make_lstm_pattern(
22+
from_keras=True,
23+
use_bias=params.get("use_bias", False),
24+
activation=params.get("activation", ""),
25+
recurrent_activation=params.get("recurrent_activation", "")
26+
)
1927

2028
def rewriter_lstm_tf2(g, ops):
21-
22-
pattern1 = make_lstm_pattern(enter_or_id="Identity") # TF LSTM
23-
pattern2 = make_lstm_pattern(from_keras=True, use_bias=False) # keras LSTM
24-
pattern3 = make_lstm_pattern(from_keras=True, use_bias=True) # keras LSTM with bias
25-
26-
for pattern in [pattern1, pattern2, pattern3]:
29+
lstm_params_variations = [
30+
# default activations
31+
{"enter_or_id": "Identity"}, # TF LSTM
32+
{"from_keras": True, "use_bias": False}, # keras LSTM
33+
{"from_keras": True, "use_bias": True}, # keras LSTM with bias
34+
# hard sigmoid as recurrent activation
35+
{"from_keras": True, "use_bias": False, "recurrent_activation": "hard_sigmoid"}, # keras LSTM
36+
{"from_keras": True, "use_bias": True, "recurrent_activation": "hard_sigmoid"} # keras LSTM with bias
37+
# Note: add other LSTM variations as needed
38+
]
39+
for params in lstm_params_variations:
40+
pattern = _make_lstm_pattern_from_params(params)
2741
matcher = GraphMatcher(pattern, allow_reorder=False)
2842
match_results = list(matcher.match_ops(ops))
2943

3044
for match_result in match_results:
31-
from_keras = pattern != pattern1
45+
is_ft_hard_sigmoid = params.get("recurrent_activation", "") == "hard_sigmoid"
46+
recurrent_activation_f = "HardSigmoid" if is_ft_hard_sigmoid else \
47+
match_result.get_op("ft").type
48+
activation_g = match_result.get_op("gt").type
49+
activation_h = match_result.get_op("ct'").type
50+
51+
default_activations = ["Relu", "Sigmoid", "Tanh"]
52+
if ((activation_g not in default_activations) or
53+
(activation_h not in default_activations) or
54+
(not is_ft_hard_sigmoid and recurrent_activation_f not in default_activations)):
55+
continue
56+
3257
activations_fgh = [
33-
match_result.get_op("ft").type,
34-
match_result.get_op("gt").type,
35-
match_result.get_op("ct'").type
58+
recurrent_activation_f,
59+
activation_g,
60+
activation_h
3661
]
37-
supported_activations = ['Relu', 'Sigmoid', 'Tanh']
38-
if any(f not in supported_activations for f in activations_fgh):
39-
continue
4062

4163
# extract input x_t
64+
from_keras = params.get("from_keras", False)
4265
if from_keras:
4366
get_item = match_result.get_op("xt")
4467
else:
@@ -134,7 +157,7 @@ def has_tensor_list_consumer(n):
134157

135158
# Wb and Rb are concatenated
136159
b_idx = None
137-
if pattern is pattern3:
160+
if from_keras and params.get("use_bias", False):
138161
bias_add = match_result.get_op("bias_add")
139162
if bias_add is not None and bias_add.data_format != "NHWC":
140163
continue

tf2onnx/rewriter/rnn_utils.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,25 @@ class REWRITER_RESULT(Enum):
3030

3131
# TensorFlow LSTMCell/BasicLSTMCell and Keras LSTM computation graph matching
3232

33+
def insert_activation(activation, name="", inputs=None):
34+
inputs = inputs if inputs else [] # to avoid empty list as default arg
35+
if activation == "hard_sigmoid":
36+
return OpTypePattern("Maximum", inputs=[
37+
OpTypePattern("Minimum", inputs=[
38+
OpTypePattern("Add|AddV2", inputs=[
39+
OpTypePattern("Mul", inputs=[
40+
*inputs,
41+
OpTypePattern("*") # mul(x, 0.2)
42+
]), OpTypePattern("*") # add(x, 0.5)
43+
]), OpTypePattern("*") # minimum(x, 1)
44+
]), OpTypePattern("*") # maximum(x, 0)
45+
])
46+
# Additional activation pattern can be added when needed:
47+
# https://www.tensorflow.org/api_docs/python/tf/keras/activations
48+
# otherwise, use default activations
49+
return OpTypePattern("Tanh|Relu|Sigmoid", name=name, inputs=inputs)
50+
51+
3352
def make_lstm_xc_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
3453
if from_keras:
3554
lstm_xh_pattern = OpTypePattern("Add|AddV2", allow_reorder=False, inputs=[
@@ -63,7 +82,8 @@ def make_lstm_xc_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
6382
])
6483

6584

66-
def make_lstm_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
85+
def make_lstm_pattern(enter_or_id="Enter", from_keras=False, use_bias=False,
86+
activation="", recurrent_activation=""):
6787
# split (Xt*(W[ifco]^T) + Ht-1*(R[ifco]^T)) on 'Const' axis
6888
lstm_xc_pattern = OpTypePattern('Split', inputs=[
6989
OpTypePattern("Const"),
@@ -77,23 +97,21 @@ def make_lstm_pattern(enter_or_id="Enter", from_keras=False, use_bias=False):
7797
OpTypePattern("*", name="ft_bias"),
7898
])
7999

80-
activation = "Tanh|Relu|Sigmoid"
81-
recurrent_activation = "Tanh|Relu|Sigmoid"
82-
83-
return OpTypePattern("Mul", name='ht', inputs=[
84-
OpTypePattern(recurrent_activation, name="ot", inputs=[lstm_xc_pattern]),
85-
OpTypePattern(activation, name="ct'", inputs=[
86-
OpTypePattern("Add|AddV2", name="ct", inputs=[
87-
OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
88-
OpTypePattern(recurrent_activation, name="ft", inputs=[lstm_fb_pattern]),
89-
OpTypePattern("*", name="c"),
90-
]),
91-
OpTypePattern("Mul", inputs=[
92-
OpTypePattern(recurrent_activation, name="it", inputs=[lstm_xc_pattern]),
93-
OpTypePattern(activation, name="gt", inputs=[lstm_xc_pattern]),
94-
]),
95-
]),
100+
# cell state
101+
lstm_ct_pattern = OpTypePattern("Add|AddV2", name="ct", inputs=[
102+
OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
103+
insert_activation(recurrent_activation, name="ft", inputs=[lstm_fb_pattern]),
104+
OpTypePattern("*", name="c"),
96105
]),
106+
OpTypePattern("Mul", inputs=[
107+
insert_activation(recurrent_activation, name="it", inputs=[lstm_xc_pattern]),
108+
insert_activation(activation, name="gt", inputs=[lstm_xc_pattern]),
109+
]),
110+
])
111+
112+
return OpTypePattern("Mul", name="ht", inputs=[
113+
insert_activation(recurrent_activation, name="ot", inputs=[lstm_xc_pattern]),
114+
insert_activation(activation, name="ct'", inputs=[lstm_ct_pattern]),
97115
])
98116

99117
lstmcell_pattern = make_lstm_pattern()

0 commit comments

Comments
 (0)