Skip to content

Commit 8699aeb

Browse files
committed
style: improve func name for error tracing
Signed-off-by: nstarman <[email protected]>
1 parent a5163f9 commit 8699aeb

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

quax/examples/zero/_core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def _(value: Zero, *, broadcast_dimensions, shape, sharding=None) -> Zero:
6161

6262

6363
@quax.register(lax.convert_element_type_p)
64-
def _(value: Zero, *, new_dtype, weak_type, sharding=None) -> Zero:
64+
def convert_element_type_zero(
65+
value: Zero, *, new_dtype, weak_type, sharding=None
66+
) -> Zero:
6567
# sharding was added around JAX 0.4.31, it seems.
6668
del weak_type, sharding
6769
return Zero(value.shape, new_dtype)

tests/myarray.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -368,11 +368,8 @@ def conv_general_dilated_p(
368368

369369

370370
@register(lax.convert_element_type_p)
371-
def convert_element_type_p(operand: MyArray, **kw: Any) -> MyArray:
372-
return replace(
373-
operand,
374-
array=lax.convert_element_type_p.bind(operand.array, **kw),
375-
)
371+
def convert_element_type_myarray(operand: MyArray, **kw: Any) -> MyArray:
372+
return replace(operand, array=lax.convert_element_type_p.bind(operand.array, **kw))
376373

377374

378375
# ==============================================================================

0 commit comments

Comments
 (0)