Skip to content

Commit 6e23f3f

Browse files
author
Me
committed
follow bigru_rewriter syntax
Signed-off-by: Me <[email protected]>
1 parent 18f1cd2 commit 6e23f3f

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tf2onnx/rewriter/bilstm_rewriter.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,18 @@ def process_bilstm(g, bi_lstms):
5454
if len(lstm_fw.inputs) > 4:
5555
lstm_inputs.extend([lstm_fw.input[4], h_node.output[0], c_node.output[0]])
5656

57-
attr = {
58-
"direction": "bidirectional",
59-
"activations": lstm_bw.get_attr_value("activations") + lstm_fw.get_attr_value("activations"),
60-
}
57+
direction = "bidirectional"
58+
attr = {}
6159
for name in rnn_utils.onnx_rnn_attr_mapping[rnn_utils.ONNX_RNN_TYPE.LSTM]:
6260
attr_val = lstm_fw.get_attr_value(name)
6361
if attr_val:
6462
attr[name] = attr_val
63+
# activation has to be took care, attr here is proto
64+
activations = [act.decode("utf-8")
65+
for act in lstm_bw.get_attr_value("activations")]
66+
activations += [act.decode("utf-8")
67+
for act in lstm_fw.get_attr_value("activations")]
68+
attr.update({"direction": direction, "activations": activations})
6569

6670
bi_lstm_node = g.make_node("LSTM", lstm_inputs, attr=attr, output_count=3)
6771
all_nodes.append(bi_lstm_node)

0 commit comments

Comments
 (0)