@@ -233,16 +233,17 @@ def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params)
233
233
# return val always set as a side-effect on c
234
234
if prim in backend_specific_translations [platform ]:
235
235
rule = backend_specific_translations [platform ][prim ]
236
- rule (c , * xla_args , ** params )
236
+ ans = rule (c , * xla_args , ** params )
237
237
elif prim in translations :
238
238
rule = translations [prim ]
239
- rule (c , * xla_args , ** params )
239
+ ans = rule (c , * xla_args , ** params )
240
240
elif prim in initial_style_translations :
241
241
rule = initial_style_translations [prim ]
242
- rule (c , axis_env , extend_name_stack (prim .name ), avals , backend ,
242
+ ans = rule (c , axis_env , extend_name_stack (prim .name ), avals , backend ,
243
243
* xla_args , ** params )
244
244
else :
245
245
raise NotImplementedError ("XLA translation rule for {} not found" .format (prim ))
246
+ assert isinstance (ans , xc ._xla .XlaOp )
246
247
c .ClearOpMetadata ()
247
248
try :
248
249
return c .Build ()
@@ -355,6 +356,7 @@ def write(v, node):
355
356
msg = "XLA translation rule for primitive '{}' not found"
356
357
raise NotImplementedError (msg .format (eqn .primitive .name ))
357
358
359
+ assert isinstance (ans , xc ._xla .XlaOp )
358
360
c .GetShape (ans ) # force xla to do shape error checking
359
361
out_nodes = xla_destructure (c , ans ) if eqn .primitive .multiple_results else [ans ]
360
362
c .ClearOpMetadata ()
0 commit comments