Skip to content

Commit 633beaa

Browse files
authored
fix autoformer (#2066)
1 parent 0606b78 commit 633beaa

File tree

6 files changed

+120
-5
lines changed

6 files changed

+120
-5
lines changed

mindnlp/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from .amp import autocast, GradScaler
4848

4949
from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
50-
return_types, linalg, fx, backends, testing, nn
50+
return_types, linalg, fx, backends, testing, nn, fft
5151

5252
from ._lowrank import svd_lowrank
5353
from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state

mindnlp/core/_tensor.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import numpy as np
3+
from functools import partial
34
import mindspore
45
from mindspore import Tensor
56
from mindspore.common.tensor import _TensorMeta
@@ -226,8 +227,42 @@ def __getitem__(self, slices):
226227
Tensor.__getitem__ = __getitem__
227228
StubTensor.__getitem__ = __getitem__
228229

230+
def _convert_numpy_slices(self, key):
231+
"""递归转换 key 中的 NumPy 整数为内置 int"""
232+
# 处理元组:遍历所有元素并递归转换
233+
if isinstance(key, tuple):
234+
return tuple(self._convert_numpy_slices(k) for k in key)
235+
236+
# 处理 slice 对象:转换 start/stop/step
237+
elif isinstance(key, slice):
238+
start = key.start
239+
stop = key.stop
240+
step = key.step
241+
242+
# 转换 NumPy 整数为 Python int
243+
if isinstance(start, np.integer):
244+
start = int(start)
245+
if isinstance(stop, np.integer):
246+
stop = int(stop)
247+
if isinstance(step, np.integer):
248+
step = int(step)
249+
250+
return slice(start, stop, step)
251+
252+
# 转换单个 NumPy 索引值
253+
elif isinstance(key, np.integer):
254+
return int(key)
255+
256+
# 其他类型(如 int、None)直接返回
257+
else:
258+
return key
259+
260+
Tensor._convert_numpy_slices = _convert_numpy_slices
261+
StubTensor._convert_numpy_slices = _convert_numpy_slices
262+
229263
origin_setitem = Tensor.__setitem__
230264
def __setitem__(self, slices, value):
265+
slices = self._convert_numpy_slices(slices)
231266
if isinstance(value, float):
232267
if value == float('inf'):
233268
value = ops.finfo(self.dtype).max
@@ -399,6 +434,14 @@ def __rmul__(self, other):
399434
Tensor.__rmul__ = __rmul__
400435
StubTensor.__rmul__ = __rmul__
401436

437+
Tensor.norm = ops.norm
438+
StubTensor.norm = ops.norm
439+
440+
def clamp_min(self, value):
441+
return ops.clamp(self, value)
442+
Tensor.clamp_min = clamp_min
443+
StubTensor.clamp_min = clamp_min
444+
402445
def _rebuild_from_type_v2(func, new_type, args, state):
403446
ret = func(*args)
404447
return ret

mindnlp/core/fft/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""fft"""
2+
from mindspore import ops
3+
from mindspore.ops._primitive_cache import _get_cache_prim
4+
from ..configs import use_pyboost
5+
from ..ops import narrow
6+
from ..nn import functional as F
7+
8+
def rfft(input, n=None, dim=-1, norm="backward"):
9+
if use_pyboost():
10+
return ops.rfft(input, n, dim, norm)
11+
if input.shape[dim] < n:
12+
pad_inf = (0, n - input.shape[dim])
13+
pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf
14+
input = F.pad(input, pad_dims)
15+
else:
16+
input = narrow(input, dim, 0, n)
17+
_rfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, False, True, norm)
18+
return _rfft(input)
19+
20+
def irfft(input, n=None, dim=-1, norm="backward"):
21+
if use_pyboost():
22+
return ops.irfft(input, n, dim, norm)
23+
if input.shape[dim] < n:
24+
pad_inf = (0, n - input.shape[dim])
25+
pad_dims = (0, 0) * (input.ndim - (dim + 1)) + pad_inf
26+
input = pad(input, pad_dims)
27+
else:
28+
input = narrow(input, dim, 0, n)
29+
_irfft = _get_cache_prim(ops.FFTWithSize)(input.ndim, True, True, norm)
30+
return _irfft(input)
31+
32+
def fftn(input, s=None, dim=None, norm=None):
33+
return ops.fftn(input, s, dim, norm)
34+
35+
def fft(input, s=None, dim=-1, norm=None):
36+
return ops.fft(input, s, dim, norm)
37+
38+
__all__ = ['fft', 'fftn', 'irfft', 'rfft']

mindnlp/core/nn/functional.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,51 @@ def apply_rotary_pos_emb(query, key, cos, sin, position_ids, cos_format=0):
246246
query, key, cos, sin, position_ids, cos_format
247247
)
248248

249+
def custom_circular_pad(x, pad):
250+
"""手动实现 torch.nn.functional.pad 的 circular 模式。
251+
252+
参数:
253+
x: 输入张量,形状为 (B, C, D1, D2, ...)
254+
pad: 填充参数,格式为 (left_N, right_N, left_{N-1}, right_{N-1}, ..., left_1, right_1)
255+
表示从最后维度开始向前定义填充大小
256+
257+
返回:
258+
循环填充后的张量
259+
"""
260+
ndim = x.dim()
261+
n_pad_dims = len(pad) // 2
262+
assert n_pad_dims <= ndim, "填充参数超过了张量的维度"
263+
264+
# 按从最后维度向前处理填充
265+
for dim in range(ndim-1, ndim-1-n_pad_dims, -1):
266+
# 当前维度的左右填充量
267+
idx = 2 * (ndim - 1 - dim) # 在pad元组中的起始位置
268+
left_pad = pad[idx]
269+
right_pad = pad[idx + 1]
270+
271+
if left_pad == 0 and right_pad == 0:
272+
continue # 跳过该维度
273+
274+
size = x.shape[dim] # 当前维度的原始长度
275+
new_size = left_pad + size + right_pad
276+
277+
# 生成循环索引: (index - left_pad) mod size
278+
index = (core.arange(new_size) - left_pad) % size
279+
x = core.index_select(x, dim, index)
280+
281+
return x
282+
249283
def pad(input, pad, mode='constant', value=0.0):
250284
if sum(pad) == 0:
251285
return input
252286
if isinstance(pad, tuple):
253287
pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
254288
if use_pyboost() and not ON_A1:
255289
return mint.nn.functional.pad(input, pad, mode, value)
256-
if mode in ['reflect', 'circular']:
290+
if mode == 'reflect':
257291
return ops.pad(input, pad, mode)
292+
if mode == 'circular':
293+
return custom_circular_pad(input, pad)
258294
new_pad = ()
259295
for idx, pad_v in enumerate(pad):
260296
if pad_v < 0:

mindnlp/core/ops/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .reduction import *
1111
from .other import *
1212
from .tensor import *
13-
# from .fft_op import *
1413
# from .spectral import *
1514
from ._inner import *
1615
from .optim import *
@@ -27,7 +26,6 @@ def load_library(lib_path):
2726
__all__.extend(blas.__all__)
2827
__all__.extend(comparison.__all__)
2928
__all__.extend(creation.__all__)
30-
# __all__.extend(fft_op.__all__)
3129
__all__.extend(pointwise.__all__)
3230
__all__.extend(random.__all__)
3331
__all__.extend(reduction.__all__)

mindnlp/core/ops/creation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False,
142142
if device is None:
143143
device= get_default_device()
144144

145-
if isinstance(size[0], (tuple, list)):
145+
if len(size) > 0 and isinstance(size[0], (tuple, list)):
146146
size = size[0]
147147

148148
if dtype is None:

0 commit comments

Comments
 (0)