@@ -227,31 +227,50 @@ def version_7(cls, ctx, node, **kwargs):
227
227
ctx .remove_input (node , node .input [1 ], 1 )
228
228
229
229
230
+ def _const_like_version_1 (ctx , node , value ):
231
+ shapes = node .output_shapes
232
+ dtypes = node .output_dtypes
233
+ ctx .remove_node (node .name )
234
+ casted_input = ctx .make_node ("Cast" , node .input , attr = {'to' : onnx_pb .TensorProto .INT64 })
235
+ const_value = ctx .make_const (utils .make_name ("value" ), np .array (value ).astype (np .int64 ))
236
+ mul_node = ctx .make_node ('Mul' , inputs = [casted_input .output [0 ], const_value .output [0 ]])
237
+ ctx .make_node ("Cast" , inputs = [mul_node .output [0 ]],
238
+ attr = {'to' : dtypes [0 ]},
239
+ name = node .name , outputs = node .output ,
240
+ shapes = shapes , dtypes = dtypes )
241
+
242
+
243
+ def _const_like_version_9 (ctx , node , value ):
244
+ dtypes = node .output_dtypes
245
+ ctx .remove_node (node .name )
246
+ shape = ctx .make_node ("Shape" , node .input ).output [0 ]
247
+ value_tensor = helper .make_tensor ("value" , dtypes [0 ], [1 ], vals = [value ])
248
+ ctx .make_node ("ConstantOfShape" , inputs = [shape ],
249
+ attr = {'value' : value_tensor },
250
+ name = node .name , outputs = node .output ,
251
+ dtypes = dtypes )
252
+
253
+
230
254
@tf_op ("ZerosLike" )
231
255
class ZerosLike :
232
256
@classmethod
233
257
def version_1 (cls , ctx , node , ** kwargs ):
234
- shapes = node .output_shapes
235
- dtypes = node .output_dtypes
236
- ctx .remove_node (node .name )
237
- casted_input = ctx .make_node ("Cast" , node .input , attr = {'to' : onnx_pb .TensorProto .INT64 })
238
- const_zero = ctx .make_const (utils .make_name ("zero" ), np .array (0 ).astype (np .int64 ))
239
- mul_node = ctx .make_node ('Mul' , inputs = [casted_input .output [0 ], const_zero .output [0 ]])
240
- ctx .make_node ("Cast" , inputs = [mul_node .output [0 ]],
241
- attr = {'to' : dtypes [0 ]},
242
- name = node .name , outputs = node .output ,
243
- shapes = shapes , dtypes = dtypes )
258
+ _const_like_version_1 (ctx , node , 0 )
244
259
245
260
@classmethod
246
261
def version_9 (cls , ctx , node , ** kwargs ):
247
- dtypes = node .output_dtypes
248
- ctx .remove_node (node .name )
249
- shape = ctx .make_node ("Shape" , node .input ).output [0 ]
250
- zero_tensor = helper .make_tensor ("value" , dtypes [0 ], [1 ], vals = [0 ])
251
- ctx .make_node ("ConstantOfShape" , inputs = [shape ],
252
- attr = {'value' : zero_tensor },
253
- name = node .name , outputs = node .output ,
254
- dtypes = dtypes )
262
+ _const_like_version_9 (ctx , node , 0 )
263
+
264
+
265
+ @tf_op ("OnesLike" )
266
+ class OnesLike :
267
+ @classmethod
268
+ def version_1 (cls , ctx , node , ** kwargs ):
269
+ _const_like_version_1 (ctx , node , 1 )
270
+
271
+ @classmethod
272
+ def version_9 (cls , ctx , node , ** kwargs ):
273
+ _const_like_version_9 (ctx , node , 1 )
255
274
256
275
257
276
@tf_op (["IteratorV2" , "FIFOQueueV2" ])
0 commit comments