Merged
python/paddle/base/dygraph/math_op_patch.py (12 additions, 0 deletions)
@@ -189,6 +189,17 @@ def _T_(var: Tensor) -> Tensor:
        out = _C_ops.transpose(var, perm)
        return out

    @property
    def _mT_(var: Tensor) -> Tensor:
        if len(var.shape) < 2:
            raise ValueError(
                f"Tensor.ndim({var.ndim}) is required to be greater than or equal to 2."
            )
        perm = list(range(len(var.shape)))
        perm[-1], perm[-2] = perm[-2], perm[-1]
        out = _C_ops.transpose(var, perm)
        return out

    eager_methods = [
        ('__neg__', _neg_),
        ('__abs__', _abs_),
@@ -203,6 +214,7 @@ def _T_(var: Tensor) -> Tensor:
        ('ndim', _ndim),
        ('size', _size_),
        ('T', _T_),
        ('mT', _mT_),
        # for logical compare
        ('__array_ufunc__', None),
    ]
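For context, a minimal dygraph sketch of the behavior this hunk registers (assuming the patched `Tensor.mT` above is in place; unlike `T`, which reverses every dimension, `mT` swaps only the last two):

import paddle

x = paddle.ones([2, 3, 5])
print(x.mT.shape)  # [2, 5, 3]: only the last two dims are swapped
print(x.T.shape)   # [5, 3, 2]: T reverses all dims

y = paddle.ones([4])
# y.mT raises ValueError, since ndim must be >= 2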
python/paddle/base/dygraph/tensor_patch_methods.py (1 addition, 0 deletions)
@@ -119,6 +119,7 @@ def _to_static_var(self, to_parameter=False, **kwargs):
        attr_not_need_keys = [
            'grad',
            'T',
            'mT',
            'place',
            '_place_str',
            'data',
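A hedged note on why `mT` joins `attr_not_need_keys`: `_to_static_var` copies a tensor's attributes onto the new static variable, and computed properties such as `T` and `mT` trigger a transpose rather than describe state, so they must be skipped. A sketch of the filtering, where the surrounding loop is an assumption (only `attr_not_need_keys` comes from this hunk):

# hypothetical sketch of the attribute filter in _to_static_var
attr_names = [
    name
    for name in dir(tensor)
    if name not in attr_not_need_keys and not name.startswith('_')
]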
@@ -528,7 +528,24 @@ def T(self):

        perm = list(range(len(self.meta.shape) - 1, -1, -1))
        perm_var = ListVariable(perm, self.graph, tracker=ConstTracker(perm))
        assert perm_var is not None
        out = self.graph.call_paddle_api(paddle.transpose, self, perm_var)
        return out

    @tensor_property
    def mT(self):
        """
        Return a new TensorVariable that wraps the result of accessing ``mT`` on this variable's wrapped value.
        """
        from .container import ListVariable

        if len(self.meta.shape) < 2:
            raise ValueError(
                f"Variable.ndim({self.ndim}) is required to be greater than or equal to 2."
            )

        perm = list(range(len(self.meta.shape)))
        perm[-1], perm[-2] = perm[-2], perm[-1]
        perm_var = ListVariable(perm, self.graph, tracker=DummyTracker([self]))
        out = self.graph.call_paddle_api(paddle.transpose, self, perm_var)
        return out

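A short sketch of what this SOT path should produce when traced (hedged: `paddle.jit.to_static` is used here as the entry point on the assumption that it drives SOT tracing; the hunk above shows `mT` lowering to a `paddle.transpose` call with the last two axes swapped):

import paddle

def matrix_transpose(x):
    return x.mT  # traced into paddle.transpose(x, [0, 2, 1]) for a 3-D input

st_fn = paddle.jit.to_static(matrix_transpose)
out = st_fn(paddle.rand([2, 3, 5]))
print(out.shape)  # [2, 5, 3]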
python/paddle/pir/math_op_patch.py (34 additions, 0 deletions)
@@ -479,6 +479,39 @@ def _T_(self):

        return _C_ops.transpose(self, perm)

    @property
    def _mT_(self):
        """
        Permute the current Value with its last two dimensions swapped.

        If `n` is the number of dimensions of `x`, `x.mT` is equivalent to
        `x.transpose([0, 1, ..., n-1, n-2])`.

        Examples:
            .. code-block:: python

                >>> import paddle
                >>> paddle.enable_static()

                >>> x = paddle.ones(shape=[2, 3, 5])
                >>> x_mT = x.mT

                >>> exe = paddle.static.Executor()
                >>> x_mT_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_mT])[0]
                >>> print(x_mT_np.shape)
                (2, 5, 3)

        """
        if len(self.shape) < 2:
            raise ValueError(
                f"Tensor.ndim({len(self.shape)}) is required to be greater than or equal to 2."
            )

        perm = list(range(len(self.shape)))
        perm[-1], perm[-2] = perm[-2], perm[-1]
Review comment (Member): The static-graph unit tests need to verify correctness under dynamic-shape scenarios, including the cases shape[-1] = 0, shape[-2] = 0, and both shape[-1] = 0 and shape[-2] = 0.
        return _C_ops.transpose(self, perm)

    def _int_(self):
        error_msg = """\
int(Tensor) is not supported in static graph mode. Because its value is not available during static mode.
@@ -973,6 +1006,7 @@ def register_hook(self, hook):
        ('astype', astype),
        ('size', _size_),
        ('T', _T_),
        ('mT', _mT_),
        ('clone', clone),
        ('clear_gradient', clear_gradient),
        ('append', append),
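Echoing the review comment above, a hedged static-graph sketch of a zero-size-dim case (the new unit tests below exercise these combinations exhaustively):

import paddle

paddle.enable_static()
x = paddle.rand([1, 2, 3, 0, 1])
print(x.mT.shape)  # [1, 2, 3, 1, 0] at graph-construction time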
test/legacy_test/test_math_op_patch_pir.py (63 additions, 0 deletions)
@@ -464,6 +464,69 @@ def test_T(self):
                    (output_x,) = exe.run(main_program, fetch_list=[x_T])
                    self.assertEqual(output_x.shape, tuple(out_shape))

    def test_mT(self):
        with paddle.pir_utils.IrGuard():
            shape = [1]
            x = paddle.rand(shape, dtype="float32")
            self.assertRaises(ValueError, getattr, x, 'mT')

            for ndim in range(2, 5):
                # shape is [1, 2], [1, 2, 3], [1, 2, 3, 4]
                shape = list(range(1, ndim + 1))
                out_shape = list(shape)
                out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
                main_program, exe, program_guard = new_program()
                with program_guard:
                    x = paddle.rand(shape, dtype="float32")
                    x_mT = x.mT
                    self.assertEqual(x_mT.shape, out_shape)
                    (output_x,) = exe.run(main_program, fetch_list=[x_mT])
                    self.assertEqual(output_x.shape, tuple(out_shape))

            shape = [1, 2, 3, 0, 1]
            out_shape = list(shape)
            out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
            main_program, exe, program_guard = new_program()
            with program_guard:
                x = paddle.rand(shape, dtype="float32")
                x_mT = x.mT
                self.assertEqual(x_mT.shape, out_shape)
                (output_x,) = exe.run(main_program, fetch_list=[x_mT])
                self.assertEqual(output_x.shape, tuple(out_shape))

            shape = [1, 2, 3, 1, 0]
            out_shape = list(shape)
            out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
            main_program, exe, program_guard = new_program()
            with program_guard:
                x = paddle.rand(shape, dtype="float32")
                x_mT = x.mT
                self.assertEqual(x_mT.shape, out_shape)
                (output_x,) = exe.run(main_program, fetch_list=[x_mT])
                self.assertEqual(output_x.shape, tuple(out_shape))

            shape = [1, 2, 3, 0, 0]
            out_shape = list(shape)
            out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
            main_program, exe, program_guard = new_program()
            with program_guard:
                x = paddle.rand(shape, dtype="float32")
                x_mT = x.mT
                self.assertEqual(x_mT.shape, out_shape)
                (output_x,) = exe.run(main_program, fetch_list=[x_mT])
                self.assertEqual(output_x.shape, tuple(out_shape))

            shape = [0, 2, 3, 0, 0]
            out_shape = list(shape)
            out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
            main_program, exe, program_guard = new_program()
            with program_guard:
                x = paddle.rand(shape, dtype="float32")
                x_mT = x.mT
                self.assertEqual(x_mT.shape, out_shape)
                (output_x,) = exe.run(main_program, fetch_list=[x_mT])
                self.assertEqual(output_x.shape, tuple(out_shape))

    def test_hash(self):
        with paddle.pir_utils.IrGuard():
            _, _, program_guard = new_program()
test/legacy_test/test_math_op_patch_var_base.py (12 additions, 0 deletions)
@@ -825,6 +825,18 @@ def test_tensor_patch_method(self):
        self.assertEqual(x_T.shape, [7, 9, 6, 3])
        np.testing.assert_array_equal(x_T.numpy(), x_np.T)

        x_np = np.random.randn(3, 6, 9, 7)
        x = paddle.to_tensor(x_np)
        x_mT = x.mT
        self.assertEqual(x_mT.shape, [3, 6, 7, 9])
        np.testing.assert_array_equal(
            x_mT.numpy(), x_np.transpose([0, 1, 3, 2])
        )

        x_np = np.random.randn(3)
        x = paddle.to_tensor(x_np)
        self.assertRaises(ValueError, getattr, x, "mT")

        self.assertTrue(inspect.ismethod(a.dot))
        self.assertTrue(inspect.ismethod(a.logsumexp))
        self.assertTrue(inspect.ismethod(a.multiplex))
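As a usage note (a sketch, not part of this PR): `mT` follows the matrix-transpose semantics of `numpy.ndarray.mT` and `torch.Tensor.mT`, which makes batched matmul expressions read naturally:

import paddle

a = paddle.rand([8, 3, 4])
b = paddle.rand([8, 3, 4])
gram = a @ b.mT  # [8, 3, 3], same as paddle.matmul(a, b, transpose_y=True)
print(gram.shape)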
test/sot/test_18_tensor_method.py (10 additions, 0 deletions)
@@ -57,6 +57,10 @@ def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor):
    )


def tensor_method_property_mT(a: paddle.Tensor):
    return a.mT


def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor):
    c = a + b
    return c.name
@@ -90,6 +94,12 @@ def test_middle_tensor_name(self):
        y = paddle.rand([42, 24])
        self.assert_results(middle_tensor_name, x, y)

    def test_tensor_method_property_mT(self):
        x = paddle.rand([42, 24, 2, 2, 3, 2], dtype='float64')
        y = paddle.rand([42, 24, 2, 3, 3, 2], dtype='float32')
        self.assert_results(tensor_method_property_mT, x)
        self.assert_results(tensor_method_property_mT, y)

if __name__ == "__main__":
    unittest.main()