pad gives wrong results in versions >= 0.4.36 #26888

@ricardoV94

Description

This snippet started failing after 0.4.36 and still fails in 0.5.1:

import jax
import numpy as np

x = np.arange(1, 10).reshape(3,3)
np.testing.assert_allclose(
    np.pad(x, mode="constant", pad_width=3, constant_values=(1, 2)),
    jax.numpy.pad(x, mode="constant", pad_width=3, constant_values=(1, 2)),
)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 9 / 81 (11.1%)
Max absolute difference among violations: 1
Max relative difference among violations: 0.5
 ACTUAL: array([[1, 1, 1, 1, 1, 1, 2, 2, 2],
       [1, 1, 1, 1, 1, 1, 2, 2, 2],
       [1, 1, 1, 1, 1, 1, 2, 2, 2],...
 DESIRED: array([[1, 1, 1, 1, 1, 1, 2, 2, 2],
       [1, 1, 1, 1, 1, 1, 2, 2, 2],
       [1, 1, 1, 1, 1, 1, 2, 2, 2],...
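
The truncated arrays in the assertion message look identical, so the nine mismatched elements are not visible there. Below is a minimal NumPy-only sketch, assuming the goal is to see where two padded results differ. `mismatch_indices` is a hypothetical helper (not part of NumPy or JAX), and the `actual` array is stand-in data, not real JAX output.

```python
import numpy as np


def mismatch_indices(expected, actual):
    """Return an (n, ndim) array of indices where the two arrays differ.

    Hypothetical helper for illustration; not part of NumPy or JAX.
    """
    expected = np.asarray(expected)
    actual = np.asarray(actual)
    return np.argwhere(expected != actual)


x = np.arange(1, 10).reshape(3, 3)

# NumPy reference from the report: pad 3 on every side, using 1 for the
# "before" pad and 2 for the "after" pad on each axis.
expected = np.pad(x, mode="constant", pad_width=3, constant_values=(1, 2))

# NumPy fills the pad regions axis by axis, so the corner blocks take the
# last axis's values: 1 in the left corners, 2 in the right corners.
corners = {
    "top-left": expected[0, 0],
    "top-right": expected[0, -1],
    "bottom-left": expected[-1, 0],
    "bottom-right": expected[-1, -1],
}
print(corners)

# To compare against JAX, pass its result as `actual`, for example
#   actual = np.asarray(jax.numpy.pad(x, mode="constant", pad_width=3,
#                                     constant_values=(1, 2)))
# Here a copy with one element changed stands in for a differing result.
actual = expected.copy()
actual[0, 0] = 0  # stand-in data only, not the real JAX output
bad = mismatch_indices(expected, actual)
print(bad)
```

Printing the full index list (or its row and column ranges) shows which block of the padded array disagrees, which the elided `...` output cannot.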

System info (python version, jaxlib version, accelerator, etc.)

platform: uname_result(system='Linux', node='fedora', release='6.12.11-200.fc41.x86_64', version='#1 SMP PREEMPT_DYNAMIC Fri Jan 24 04:59:58 UTC 2025', machine='x86_64')

Labels: bug (Something isn't working)