Skip to content

Commit 87d9590

Browse files
authored
Add a dynamic type check that the value returned by an XLA translation rule is an XlaOp. (#2723)
Helps give a more understandable error on erroneous translation rules.
1 parent 821193b commit 87d9590

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

jax/interpreters/xla.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,16 +233,17 @@ def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params)
233233
# return val always set as a side-effect on c
234234
if prim in backend_specific_translations[platform]:
235235
rule = backend_specific_translations[platform][prim]
236-
rule(c, *xla_args, **params)
236+
ans = rule(c, *xla_args, **params)
237237
elif prim in translations:
238238
rule = translations[prim]
239-
rule(c, *xla_args, **params)
239+
ans = rule(c, *xla_args, **params)
240240
elif prim in initial_style_translations:
241241
rule = initial_style_translations[prim]
242-
rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
242+
ans = rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
243243
*xla_args, **params)
244244
else:
245245
raise NotImplementedError("XLA translation rule for {} not found".format(prim))
246+
assert isinstance(ans, xc._xla.XlaOp)
246247
c.ClearOpMetadata()
247248
try:
248249
return c.Build()
@@ -355,6 +356,7 @@ def write(v, node):
355356
msg = "XLA translation rule for primitive '{}' not found"
356357
raise NotImplementedError(msg.format(eqn.primitive.name))
357358

359+
assert isinstance(ans, xc._xla.XlaOp)
358360
c.GetShape(ans) # force xla to do shape error checking
359361
out_nodes = xla_destructure(c, ans) if eqn.primitive.multiple_results else [ans]
360362
c.ClearOpMetadata()

0 commit comments

Comments
 (0)