
Commit ab10189

fix apis for l-n class (#2157)
1 parent a00b56c commit ab10189

File tree

10 files changed, +119 -29 lines changed


mindnlp/core/_prims/ascend.py

Lines changed: 8 additions & 0 deletions
@@ -108,6 +108,7 @@ def pad_v3(input_x, padding, mode='constant', value=None):
     pad_op = ops.PadV3(mode=mode, paddings_contiguous=True).set_device('Ascend')
     if input_x.dtype == core.bool:
         input_x = input_x.to(core.int32)
+        value = int(value)
         out = pad_op(input_x, padding, value)
         return cast(out, core.bool)

@@ -117,6 +118,7 @@ def pad_v3(input_x, padding, mode='constant', value=None):

 __all__.append('pad_v3')

+
 def inplace_uniform(input, from_, to_, generator_):
     seed, offset = generator_._step(12)
     return gen_ops_prim.inplace_uniform_op(input, from_, to_, seed, offset)

@@ -413,3 +415,9 @@ def bucketize(input, boundaries, right):
     return bucketize_op(input)

 __all__.append('bucketize')
+
+def dropout2d(input, p):
+    dropout_2d_op = ops.Dropout2D(1.0 - p)
+    return dropout_2d_op(input)
+
+__all__.append('dropout2d')
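
For context, a minimal sketch of what the new Ascend dropout2d prim relies on, assuming MindSpore's ops.Dropout2D semantics (a keep_prob constructor argument and an (output, mask) return); the mindnlp wiring itself is not shown:

    # Sketch only: converts a PyTorch-style drop probability p into MindSpore's keep_prob.
    import numpy as np
    import mindspore as ms
    from mindspore import ops

    p = 0.4                                  # drop probability, PyTorch-style
    dropout_2d_op = ops.Dropout2D(1.0 - p)   # MindSpore expects the keep probability
    x = ms.Tensor(np.ones((2, 3, 4, 4)), ms.float32)
    out, mask = dropout_2d_op(x)             # zeroes whole channels; returns (output, mask)
    print(out.shape, mask.shape)             # (2, 3, 4, 4) (2, 3, 4, 4)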

mindnlp/core/_prims/meta.py

Lines changed: 14 additions & 0 deletions
@@ -356,3 +356,17 @@ def bitwise_xor_tensor(input, other):
     return input

 __all__.append('bitwise_xor_tensor')
+
+def divmod(input, other, rounding_mode):
+    if isinstance(input, core.Tensor):
+        return input
+    return other
+
+__all__.append('divmod')
+
+def greater_equal(input, other):
+    if isinstance(input, core.Tensor):
+        return input
+    return other
+
+__all__.append('greater_equal')
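
The meta prims are shape/dtype placeholders rather than real computations: divmod and greater_equal simply hand back whichever argument is a tensor, so meta-device tracing gets an object of the right kind without doing any arithmetic. A standalone sketch of that pass-through pattern (names below are illustrative, not the mindnlp internals):

    # Minimal sketch of the meta-backend pass-through pattern.
    class FakeTensor:
        pass

    def meta_divmod(input, other, rounding_mode=None):
        # no arithmetic on the meta device: return a tensor-like placeholder
        return input if isinstance(input, FakeTensor) else other

    t = FakeTensor()
    assert meta_divmod(t, 3, "floor") is t   # tensor op scalar -> tensor placeholder
    assert meta_divmod(3, t, "floor") is t   # scalar op tensor -> tensor placeholder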

mindnlp/core/_prims/numpy.py

Lines changed: 62 additions & 5 deletions
@@ -43,17 +43,23 @@ def div(input, other):
 __all__.append('div')

 def pow_scalar_tensor(input, other):
-    out = np.power(input, other.numpy())
+    other = other.numpy()
+    out = np.power(input, other)
+    if out.dtype == np.float64:
+        out = out.astype(np.float32)
     return core.Tensor.from_numpy(out)

 __all__.append('pow_scalar_tensor')

 def mul(input, other):
     if not isinstance(input, numbers.Number):
         input = input.asnumpy()
-    elif not isinstance(other, numbers.Number):
+    if not isinstance(other, numbers.Number):
         other = other.asnumpy()
-    out = np.multiply(input, other)
+
+    out = input * other
+    if out.dtype == np.float64:
+        out = out.astype(np.float32)
     if not isinstance(out, np.ndarray):
         out = np.array(out)
     return core.Tensor.from_numpy(out)

@@ -598,7 +604,10 @@ def inplace_add_ext(input, other, alpha):
 __all__.append('inplace_add_ext')

 def pow_tensor_scalar(input, other):
-    out = np.power(input.numpy(), other)
+    input = input.numpy()
+    if input.dtype == np.int64:
+        input = input.astype(np.int32)
+    out = np.power(input, other)
     if not isinstance(out, np.ndarray):
         out = np.array(out)
     return core.Tensor.from_numpy(out)

@@ -731,8 +740,10 @@ def divmod(input, other, rounding_mode):
     if rounding_mode == 'floor':
         out = np.floor_divide(input, other)
     elif rounding_mode == 'trunc':
-        out = np.trunc(np.true_divide(input, other))
+        out = np.trunc(np.true_divide(input, other)).astype(np.int64)

+    if not isinstance(out, np.ndarray):
+        out = np.array(out)
     return core.Tensor.from_numpy(out)

 __all__.append('divmod')

@@ -801,6 +812,12 @@ def repeat_interleave_tensor(input, repeats, dim, _):

 __all__.append('repeat_interleave_tensor')

+def repeat_interleave_int(input, repeats, dim, _):
+    out = np.repeat(input.numpy(), repeats, dim)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('repeat_interleave_int')
+
 def greater(input, other):
     if not isinstance(input, numbers.Number):
         input = input.numpy()

@@ -823,6 +840,8 @@ def linalg_vector_norm(input, p, dim, keepdim, dtype):

 def exp(input):
     out = np.exp(input.numpy())
+    if input.dtype == np.int64:
+        out = out.astype(np.float32)
     return core.Tensor.from_numpy(out)

 __all__.append('exp')

@@ -917,3 +936,41 @@ def floor(input):
     return core.Tensor.from_numpy(out)

 __all__.append('floor')
+
+def chunk(input, chunks, dim):
+    out = np.array_split(input.numpy(), chunks, dim)
+    out = [core.Tensor.from_numpy(o) for o in out]
+    return out
+
+__all__.append('chunk')
+
+def narrow(input, dim, start, length):
+    slices = [slice(None)] * input.ndim
+    # restrict the given dim to the slice [start : start + length]
+    slices[dim] = slice(start, start + length)
+    # apply the slice and return the result
+    out = input.numpy()[tuple(slices)]
+    return core.Tensor.from_numpy(out)
+
+__all__.append('narrow')
+
+def roll(input, shifts, dims):
+    out = np.roll(input.numpy(), shifts, dims)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('roll')
+
+def outer(input, other):
+    out = np.outer(input.numpy(), other.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('outer')
+
+def one_hot_ext(tensor, num_classes=-1):
+    if num_classes == -1:
+        num_classes = np.max(tensor.numpy()) + 1  # infer the number of classes automatically
+
+    out = np.eye(num_classes)[tensor.numpy()]
+    return core.Tensor.from_numpy(out)
+
+__all__.append('one_hot_ext')
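
The new numpy-backend helpers are thin wrappers over standard NumPy calls: one_hot_ext builds one-hot rows by indexing an identity matrix, and narrow slices a single dimension. A quick standalone check of the same NumPy idioms, outside mindnlp:

    import numpy as np

    # one_hot_ext idiom: index rows of an identity matrix
    labels = np.array([0, 2, 1])
    num_classes = np.max(labels) + 1        # inferred when num_classes == -1
    one_hot = np.eye(num_classes)[labels]
    print(one_hot)
    # [[1. 0. 0.]
    #  [0. 0. 1.]
    #  [0. 1. 0.]]

    # narrow idiom: slice [start : start + length] along one dimension
    x = np.arange(12).reshape(3, 4)
    dim, start, length = 1, 1, 2
    slices = [slice(None)] * x.ndim
    slices[dim] = slice(start, start + length)
    print(x[tuple(slices)])                 # columns 1..2 of every row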

mindnlp/core/_tensor.py

Lines changed: 7 additions & 4 deletions
@@ -173,6 +173,7 @@ def cuda(self, device=None, non_blocking=False):

 def requires_grad_(self, requires_grad=True):
     self.requires_grad = requires_grad
+    return self

 def __reduce_ex__(self, protocol):
     if isinstance(self, StubTensor):

@@ -293,6 +294,8 @@ def __rtruediv__ (self, other):
     return ops.div(other, self)

 def __ne__(self, other):
+    if isinstance(other, list):
+        return True
     return ops.ne(self, other)

 def __neg__(self):

@@ -2122,10 +2125,10 @@ def untyped_storage(self):


 # Tensor.stride
-def stride(self, dim=None):
-    if dim is None:
-        return self._data.stride()
-    return self._data.stride()[dim]
+# def stride(self, dim=None):
+#     if dim is None:
+#         return self.stride()
+#     return self.stride()[dim]


 # Tensor.sub
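
requires_grad_ now returns self, so calls can be chained the way they are in PyTorch, and __ne__ short-circuits to True when compared against a plain list. A plain-Python stand-in (not the mindnlp Tensor) showing why the return value matters:

    # Sketch only: illustrates the chaining enabled by returning self.
    class T:
        def __init__(self):
            self.requires_grad = False
        def requires_grad_(self, requires_grad=True):
            self.requires_grad = requires_grad
            return self          # without this, the call below would yield None

    t = T().requires_grad_()
    assert t.requires_grad is True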

mindnlp/core/nn/functional.py

Lines changed: 15 additions & 14 deletions
@@ -7,7 +7,7 @@
 from mindnlp import core
 from mindnlp.core.executor import execute

-from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1, ON_A2
+from ..configs import ON_ORANGE_PI, use_pyboost, ON_A1, ON_A2

 generator_step_ = 12

@@ -74,7 +74,7 @@ def hardsigmoid(input, inplace=False):
     return ops.hardsigmoid(input)

 def hardswish(input: core.Tensor, inplace: bool = False) -> core.Tensor:
-    return ops.hardswish(input)
+    return execute('hswish', input)

 def hardshrink(input, lambd=0.5):
     return execute('hard_shrink', input, lambd)

@@ -129,7 +129,7 @@ def adaptive_avg_pool2d(input, output_size):
     return execute('adaptive_avg_pool2d_ext', input, output_size)

 def dropout(input, p=0.5, training=True, inplace=False):
-    if not training:
+    if not training or p==0:
         return input
     out, _ = execute('dropout_ext', input, p)
     if inplace:

@@ -138,7 +138,10 @@ def dropout(input, p=0.5, training=True, inplace=False):
     return out

 def dropout2d(input, p=0.5, training=False):
-    return ops.dropout2d(input, p, training)
+    if not training or p==0:
+        return input
+    out, _ = execute('dropout2d', input, p)
+    return out

 def drop_and_mask(keep_prob, seed=None):
     seed0, seed1 = _get_seed(seed, "dropout")

@@ -301,6 +304,9 @@ def pad(input, pad, mode='constant', value=None):
             return execute('pad_v3', input, new_pad, mode)
         if value is None:
             value = 0
+        if mode == "replicate":
+            mode = "edge"
+            return execute('pad_v3', input, new_pad, mode)
         return execute('pad_v3', input, new_pad, mode, value)
     out = input
     if (isinstance(pad, tuple) and not pad):

@@ -1541,8 +1547,8 @@ def _canonical_mask(
 ) -> Optional[core.Tensor]:
     if mask is not None:
         _mask_dtype = mask.dtype
-        _mask_is_float = ops.is_floating_point(mask)
-        if _mask_dtype != mindspore.bool_ and not _mask_is_float:
+        _mask_is_float = core.is_floating_point(mask)
+        if _mask_dtype != core.bool and not _mask_is_float:
            raise AssertionError(
                f"only bool and floating types of {mask_name} are supported")
        if check_other and other_type is not None:

@@ -1552,8 +1558,8 @@ def _canonical_mask(
                "is deprecated. Use same type for both instead."
            )
        if not _mask_is_float:
-            zero_tensor = ops.zeros_like(mask, dtype=target_type)
-            mask = ops.where(mask, core.Tensor(float("-inf"), target_type), zero_tensor)
+            zero_tensor = core.zeros_like(mask, dtype=target_type, device=mask.device)
+            mask = core.where(mask, core.tensor(float("-inf"), dtype=target_type, device=mask.device), zero_tensor)
        # mask = (
        #     ops.zeros_like(mask, dtype=target_type)
        #     .masked_fill_(mask, float("-inf"))

@@ -1571,14 +1577,9 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
     if ON_A1:
         return execute('im2col', input, kernel_size, dilation, padding, stride)
     return execute('im2col_ext', input, kernel_size, dilation, padding, stride)
-    if use_pyboost() and not ON_A1:
-        return mint.nn.functional.unfold(input, kernel_size, dilation, padding, stride)
-    return ops.unfold(input, kernel_size, dilation, padding, stride)

 def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
-    if use_pyboost():
-        return mint.nn.functional.fold(input, output_size, kernel_size, dilation, padding, stride)
-    return ops.fold(input, output_size, kernel_size, dilation, padding, stride)
+    return execute('col2im_ext', input, output_size, kernel_size, dilation, padding, stride)

 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
     return execute('ctc_loss', log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
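
pad() now maps mode="replicate" onto the backend's "edge" mode before dispatching. That is the same naming difference NumPy uses, where edge padding repeats the border values; a quick NumPy illustration of the equivalence being assumed:

    import numpy as np

    x = np.array([[1, 2],
                  [3, 4]])
    # "replicate" padding in the PyTorch sense corresponds to NumPy's mode='edge'
    print(np.pad(x, ((0, 0), (1, 1)), mode='edge'))
    # [[1 1 2 2]
    #  [3 3 4 4]]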

mindnlp/core/npu/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ def npu_fusion_attention(query, key, value, head_num, input_layout, *, pse=None,
                         scale=1., keep_prob=1., pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0,
                         drop_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                         gen_mask_parallel=True, sync=False, pse_type=1, q_start_idx=None, kv_start_idx=None):
-    output = gen.flash_attention_score_impl(
+    output = execute(
+        'flash_attention_score',
         query, key, value, real_shift=pse, padding_mask=padding_mask, drop_mask=drop_mask,
         attn_mask=atten_mask, prefix=prefix, actual_seq_qlen=actual_seq_qlen,
         actual_seq_kvlen=actual_seq_kvlen, head_num=head_num, keep_prob=float(keep_prob),

mindnlp/core/ops/array.py

Lines changed: 2 additions & 1 deletion
@@ -201,6 +201,7 @@ def _get_moved_perm(ndim, source, destination):

 # narrow
 def narrow(input, dim, start, length):
+    length = length.item() if not isinstance(length, int) else length
     return execute("narrow", input, dim, start, length)


@@ -393,7 +394,7 @@ def tensor_split(input, indices_or_sections, dim=0):
 def tile(input, dims):
     if isinstance(dims[0], (tuple, list)):
         dims = dims[0]
-    return execute("tile", input, dims)
+    return execute("tile", input, tuple(dims))


 # transpose
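
narrow() now tolerates a 0-d tensor for length by calling .item(), and tile() coerces list dims to a tuple before dispatch. A standalone sketch of the argument normalization (helper names are illustrative, no mindnlp imports):

    def normalize_length(length):
        # accept either a plain int or anything exposing .item() (e.g. a 0-d tensor)
        return length if isinstance(length, int) else length.item()

    def normalize_dims(dims):
        # tile may receive tile(x, (2, 3)) or tile(x, 2, 3); always pass a tuple on
        if isinstance(dims[0], (tuple, list)):
            dims = dims[0]
        return tuple(dims)

    assert normalize_length(3) == 3
    assert normalize_dims([2, 3]) == (2, 3)
    assert normalize_dims(((2, 3),)) == (2, 3)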

mindnlp/core/ops/creation.py

Lines changed: 5 additions & 2 deletions
@@ -12,7 +12,10 @@
 from .._bind import get_default_dtype, get_device_in_context

 def as_strided(self, size, stride, storage_offset=None):
-    return execute('as_strided', self, size, stride, storage_offset)
+    size = [s if isinstance(s, int) else s.item() for s in size]
+    if storage_offset is None:
+        storage_offset = 0
+    return execute('as_strided', self, tuple(size), tuple(stride), storage_offset)

 # from_numpy
 def from_numpy(ndarray):

@@ -37,7 +40,7 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, requires_grad=F
     if device is None:
         device = get_device_in_context()

-    if isinstance(device, str):
+    if isinstance(device, (str, int)):
         device = core.device(device)
     if len(size) > 0 and isinstance(size[0], (tuple, list)):
         size = size[0]
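
as_strided now converts tensor-valued sizes to Python ints and defaults storage_offset to 0, while zeros() additionally accepts an integer device index. A standalone sketch of the size/offset clean-up (the function name is illustrative):

    def normalize_as_strided_args(size, stride, storage_offset=None):
        # 0-d tensors in size become plain ints; a missing offset becomes 0
        size = [s if isinstance(s, int) else s.item() for s in size]
        if storage_offset is None:
            storage_offset = 0
        return tuple(size), tuple(stride), storage_offset

    print(normalize_as_strided_args([2, 3], [3, 1]))   # ((2, 3), (3, 1), 0)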

mindnlp/core/ops/other.py

Lines changed: 2 additions & 1 deletion
@@ -97,7 +97,8 @@ def clone(input, *, memory_format=core.preserve_format):
 # cumprod

 # cumsum
-def cumsum(input, dim, dtype=None):
+def cumsum(input, dim=None, dtype=None, **kwargs):
+    dim = kwargs.pop('axis', dim)
     if input.dtype in [core.int64, core.bool]:
         return execute('cumsum_ext', input.int(), dim, None).long()
     if dtype is not None and dtype == core.int64:
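
cumsum now accepts the NumPy/MindSpore-style axis keyword as an alias for dim, pulled out of **kwargs before dispatch. A standalone toy version of the same aliasing pattern (not the real op):

    def cumsum_like(values, dim=None, **kwargs):
        dim = kwargs.pop('axis', dim)   # axis= falls through to dim
        # 1-D toy: dim is resolved but not needed for a flat list
        total, out = 0, []
        for v in values:
            total += v
            out.append(total)
        return out

    assert cumsum_like([1, 2, 3], axis=0) == [1, 3, 6]
    assert cumsum_like([1, 2, 3], dim=0) == [1, 3, 6]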

mindnlp/core/ops/reduction.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ def any(input, dim=None, keepdim=False):
     return execute('reduce_any', input, dim, keepdim)

 # max
-def max(input, dim=None, keepdim=False, *, out=None):
+def max(input, dim=None, keepdim=False, *, out=None, **kwargs):
+    dim = kwargs.pop('axis', dim)
     if dim is None and not keepdim:
         return execute('max', input)
     if core.is_tensor(dim):
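
max() gets the same axis-to-dim fallback as cumsum above, so NumPy-style callers that pass axis= no longer fail. A minimal standalone check of the keyword resolution (helper name is illustrative):

    def resolve_dim(dim=None, **kwargs):
        return kwargs.pop('axis', dim)

    assert resolve_dim(axis=1) == 1   # NumPy-style caller
    assert resolve_dim(dim=1) == 1    # PyTorch-style caller
    assert resolve_dim() is None      # full reduction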
