Skip to content

Commit ae4c39e

Browse files
Fix unsupported ops TF 2.14.0: OnesLike (#2270)
* add OnesLike handler * add tests for OnesLike --------- Signed-off-by: Alexander Gerstenberger <[email protected]> Co-authored-by: Alexander Gerstenberger <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent 278bf8a commit ae4c39e

File tree

2 files changed

+47
-18
lines changed

2 files changed

+47
-18
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4113,6 +4113,16 @@ def func(x, y):
41134113
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True,
41144114
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False))
41154115

4116+
def test_ones_like(self):
4117+
input_x = np.random.random_sample([3, 16, 16]).astype(np.float32)
4118+
input_y = np.array([16, 16, 3]).astype(np.int64)
4119+
4120+
def func(x, y):
4121+
z = tf.reshape(x, y)
4122+
return tf.ones_like(z, name=_TFOUTPUT)
4123+
4124+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y})
4125+
41164126
@check_opset_min_version(9, "is_nan")
41174127
def test_isnan(self):
41184128
# only compatible with dtype `float32`

tf2onnx/onnx_opset/generator.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -227,31 +227,50 @@ def version_7(cls, ctx, node, **kwargs):
227227
ctx.remove_input(node, node.input[1], 1)
228228

229229

230+
def _const_like_version_1(ctx, node, value):
231+
shapes = node.output_shapes
232+
dtypes = node.output_dtypes
233+
ctx.remove_node(node.name)
234+
casted_input = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
235+
const_value = ctx.make_const(utils.make_name("value"), np.array(value).astype(np.int64))
236+
mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_value.output[0]])
237+
ctx.make_node("Cast", inputs=[mul_node.output[0]],
238+
attr={'to': dtypes[0]},
239+
name=node.name, outputs=node.output,
240+
shapes=shapes, dtypes=dtypes)
241+
242+
243+
def _const_like_version_9(ctx, node, value):
244+
dtypes = node.output_dtypes
245+
ctx.remove_node(node.name)
246+
shape = ctx.make_node("Shape", node.input).output[0]
247+
value_tensor = helper.make_tensor("value", dtypes[0], [1], vals=[value])
248+
ctx.make_node("ConstantOfShape", inputs=[shape],
249+
attr={'value': value_tensor},
250+
name=node.name, outputs=node.output,
251+
dtypes=dtypes)
252+
253+
230254
@tf_op("ZerosLike")
231255
class ZerosLike:
232256
@classmethod
233257
def version_1(cls, ctx, node, **kwargs):
234-
shapes = node.output_shapes
235-
dtypes = node.output_dtypes
236-
ctx.remove_node(node.name)
237-
casted_input = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
238-
const_zero = ctx.make_const(utils.make_name("zero"), np.array(0).astype(np.int64))
239-
mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_zero.output[0]])
240-
ctx.make_node("Cast", inputs=[mul_node.output[0]],
241-
attr={'to': dtypes[0]},
242-
name=node.name, outputs=node.output,
243-
shapes=shapes, dtypes=dtypes)
258+
_const_like_version_1(ctx, node, 0)
244259

245260
@classmethod
246261
def version_9(cls, ctx, node, **kwargs):
247-
dtypes = node.output_dtypes
248-
ctx.remove_node(node.name)
249-
shape = ctx.make_node("Shape", node.input).output[0]
250-
zero_tensor = helper.make_tensor("value", dtypes[0], [1], vals=[0])
251-
ctx.make_node("ConstantOfShape", inputs=[shape],
252-
attr={'value': zero_tensor},
253-
name=node.name, outputs=node.output,
254-
dtypes=dtypes)
262+
_const_like_version_9(ctx, node, 0)
263+
264+
265+
@tf_op("OnesLike")
266+
class OnesLike:
267+
@classmethod
268+
def version_1(cls, ctx, node, **kwargs):
269+
_const_like_version_1(ctx, node, 1)
270+
271+
@classmethod
272+
def version_9(cls, ctx, node, **kwargs):
273+
_const_like_version_9(ctx, node, 1)
255274

256275

257276
@tf_op(["IteratorV2", "FIFOQueueV2"])

0 commit comments

Comments
 (0)