Skip to content

Commit fe1547f

Browse files
fix
1 parent d6f7f40 commit fe1547f

File tree

7 files changed

+53
-87
lines changed

7 files changed

+53
-87
lines changed

python/paddle/base/framework.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2495,65 +2495,6 @@ def T(self):
24952495
)
24962496
return out
24972497

2498-
@property
2499-
def mT(self):
2500-
"""
2501-
2502-
Returns a view of this tensor with the last two dimensions transposed.
2503-
2504-
If `n` is the dimensions of `x` , `x.mT` is equivalent to `x.transpose([0, 1, ..., n-1, n-2])`.
2505-
2506-
Examples:
2507-
.. code-block:: python
2508-
2509-
>>> import paddle
2510-
>>> paddle.enable_static()
2511-
2512-
>>> x = paddle.ones(shape=[2, 3, 5])
2513-
>>> x_mT = x.mT
2514-
2515-
>>> exe = paddle.static.Executor()
2516-
>>> x_mT_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_mT])[0]
2517-
>>> print(x_mT_np.shape)
2518-
(2, 5, 3)
2519-
2520-
"""
2521-
if len(self.shape) < 2:
2522-
raise ValueError(
2523-
f"Tensor.ndim({self.ndim}) is required to be greater than or equal to 2."
2524-
)
2525-
2526-
perm = list(range(len(self.shape)))
2527-
perm[-1], perm[-2] = perm[-2], perm[-1]
2528-
2529-
with unique_name.guard(self.block.program._name_generator):
2530-
out = self.block.create_var(
2531-
name=unique_name.generate_with_ignorable_key(
2532-
self.name + ".tmp"
2533-
),
2534-
dtype=self.dtype,
2535-
type=self.type,
2536-
persistable=False,
2537-
stop_gradient=False,
2538-
)
2539-
input_shape = self.block.create_var(
2540-
name=unique_name.generate_with_ignorable_key(
2541-
self.name + ".tmp"
2542-
),
2543-
dtype=self.dtype,
2544-
type=core.VarDesc.VarType.LOD_TENSOR,
2545-
persistable=False,
2546-
stop_gradient=False,
2547-
)
2548-
2549-
self.block.append_op(
2550-
type="transpose2",
2551-
inputs={"X": [self]},
2552-
outputs={"Out": [out], "XShape": [input_shape]},
2553-
attrs={"axis": perm},
2554-
)
2555-
return out
2556-
25572498
def clone(self):
25582499
"""
25592500
Returns a new static Variable, which is the clone of the original static

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,6 @@ def T(self):
528528

529529
perm = list(range(len(self.meta.shape) - 1, -1, -1))
530530
perm_var = ListVariable(perm, self.graph, tracker=ConstTracker(perm))
531-
assert perm_var is not None
532531
out = self.graph.call_paddle_api(paddle.transpose, self, perm_var)
533532
return out
534533

@@ -546,9 +545,7 @@ def mT(self):
546545

547546
perm = list(range(len(self.meta.shape)))
548547
perm[-1], perm[-2] = perm[-2], perm[-1]
549-
550-
perm_var = ListVariable(perm, self.graph, tracker=ConstTracker(perm))
551-
assert perm_var is not None
548+
perm_var = ListVariable(perm, self.graph, tracker=DummyTracker(self))
552549
out = self.graph.call_paddle_api(paddle.transpose, self, perm_var)
553550
return out
554551

python/paddle/pir/math_op_patch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,11 @@ def _mT_(self):
502502
(2, 5, 3)
503503
504504
"""
505-
if len(self.shape) == 1:
506-
return self
505+
if len(self.shape) < 2:
506+
raise ValueError(
507+
f"Tensor.ndim({len(self.shape)}) is required to be greater than or equal to 2."
508+
)
509+
507510
perm = list(range(len(self.shape)))
508511
perm[-1], perm[-2] = perm[-2], perm[-1]
509512

test/legacy_test/test_math_op_patch.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -375,25 +375,6 @@ def test_T(self):
375375
)
376376
np.testing.assert_array_equal(out[0], out_np)
377377

378-
@prog_scope()
379-
def test_mT(self):
380-
x_np = np.random.randint(-100, 100, [2, 8, 5, 3]).astype("int32")
381-
out_np = x_np.transpose([0, 1, 3, 2])
382-
383-
x = paddle.static.data(name="x", shape=[2, 8, 5, 3], dtype="int32")
384-
z = x.mT
385-
386-
exe = base.Executor()
387-
out = exe.run(
388-
base.default_main_program(), feed={"x": x_np}, fetch_list=[z]
389-
)
390-
np.testing.assert_array_equal(out[0], out_np)
391-
392-
x_np = np.random.randint(-100, 100, [2]).astype("int32")
393-
394-
x = paddle.static.data(name="x", shape=[2, 8, 5, 3], dtype="int32")
395-
self.assertRaises(ValueError, getattr, x, "mT")
396-
397378
@prog_scope()
398379
def test_ndim(self):
399380
a = paddle.static.data(name="a", shape=[10, 1])

test/legacy_test/test_math_op_patch_pir.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,50 @@ def test_mT(self):
483483
(output_x,) = exe.run(main_program, fetch_list=[x_mT])
484484
self.assertEqual(output_x.shape, tuple(out_shape))
485485

486+
shape = [1, 2, 3, 0, 1]
487+
out_shape = list(shape)
488+
out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
489+
main_program, exe, program_guard = new_program()
490+
with program_guard:
491+
x = paddle.rand(shape, dtype="float32")
492+
x_mT = x.mT
493+
self.assertEqual(x_mT.shape, out_shape)
494+
(output_x,) = exe.run(main_program, fetch_list=[x_mT])
495+
self.assertEqual(output_x.shape, tuple(out_shape))
496+
497+
shape = [1, 2, 3, 1, 0]
498+
out_shape = list(shape)
499+
out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
500+
main_program, exe, program_guard = new_program()
501+
with program_guard:
502+
x = paddle.rand(shape, dtype="float32")
503+
x_mT = x.mT
504+
self.assertEqual(x_mT.shape, out_shape)
505+
(output_x,) = exe.run(main_program, fetch_list=[x_mT])
506+
self.assertEqual(output_x.shape, tuple(out_shape))
507+
508+
shape = [1, 2, 3, 0, 0]
509+
out_shape = list(shape)
510+
out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
511+
main_program, exe, program_guard = new_program()
512+
with program_guard:
513+
x = paddle.rand(shape, dtype="float32")
514+
x_mT = x.mT
515+
self.assertEqual(x_mT.shape, out_shape)
516+
(output_x,) = exe.run(main_program, fetch_list=[x_mT])
517+
self.assertEqual(output_x.shape, tuple(out_shape))
518+
519+
shape = [0, 2, 3, 0, 0]
520+
out_shape = list(shape)
521+
out_shape[-2], out_shape[-1] = out_shape[-1], out_shape[-2]
522+
main_program, exe, program_guard = new_program()
523+
with program_guard:
524+
x = paddle.rand(shape, dtype="float32")
525+
x_mT = x.mT
526+
self.assertEqual(x_mT.shape, out_shape)
527+
(output_x,) = exe.run(main_program, fetch_list=[x_mT])
528+
self.assertEqual(output_x.shape, tuple(out_shape))
529+
486530
def test_hash(self):
487531
with paddle.pir_utils.IrGuard():
488532
_, _, program_guard = new_program()

test/legacy_test/test_math_op_patch_var_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def test_tensor_patch_method(self):
828828
x_np = np.random.randn(3, 6, 9, 7)
829829
x = paddle.to_tensor(x_np)
830830
x_mT = x.mT
831-
self.assertTrue(x_mT.shape, [7, 9, 6, 3])
831+
self.assertTrue(x_mT.shape, [3, 6, 7, 9])
832832
np.testing.assert_array_equal(
833833
x_mT.numpy(), x_np.transpose([0, 1, 3, 2])
834834
)

test/sot/test_18_tensor_method.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ def test_middle_tensor_name(self):
9595
self.assert_results(middle_tensor_name, x, y)
9696

9797
def test_tensor_method_property_mT(self):
98-
x = paddle.rand([42, 24], dtype='float64')
99-
y = paddle.rand([42, 24], dtype='float32')
98+
x = paddle.rand([42, 24, 2, 2, 3, 2], dtype='float64')
99+
y = paddle.rand([42, 24, 2, 3, 3, 2], dtype='float32')
100100
self.assert_results(tensor_method_property_mT, x)
101101
self.assert_results(tensor_method_property_mT, y)
102102

0 commit comments

Comments
 (0)