
Commit 6719e5e

fix apis for d class (#2151)
1 parent 93e646a commit 6719e5e

10 files changed (+181, -36 lines)

mindnlp/core/_prims/ascend.py

Lines changed: 43 additions & 0 deletions
@@ -332,3 +332,46 @@ def grid_sampler_3d(input, grid, mode, padding_mode, align_corners):
 
 __all__.append('grid_sampler_2d')
 __all__.append('grid_sampler_3d')
+
+def rms_norm(x, gamma, epsilon):
+    return pyboost_inner_prim.rms_norm_impl(x, gamma, epsilon)[0]
+
+__all__.append('rms_norm')
+
+_complex = ops.Complex().set_device('Ascend')
+def view_as_complex(input):
+    real_part, imag_part = input.tensor_split(2, -1)
+    return _complex(real_part.squeeze(-1), imag_part.squeeze(-1))
+
+__all__.append('view_as_complex')
+
+imag_op = ops.Imag().set_device('Ascend')
+def imag(input):
+    return imag_op(input)
+
+__all__.append('imag')
+
+def glu(x, axis):
+    return pyboost_inner_prim.glu_impl(x, axis)
+
+__all__.append('glu')
+
+def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
+    ctc_loss_op = _get_cache_prim(ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity).set_device('Ascend')
+    if targets.ndim == 1:
+        targets = targets.unsqueeze(-1)
+    loss, _ = ctc_loss_op(log_probs, targets, input_lengths, target_lengths)
+    if zero_infinity:
+        loss = select(isinf(loss), 0., loss)
+    if reduction == 'sum':
+        loss = sum_ext(loss)
+    if reduction == 'mean':
+        input_type = loss.dtype
+        target_length_t = target_lengths.clip(1., None)
+        loss = loss.astype("float32")
+        loss = div(loss, target_length_t)
+        loss = mean_ext(loss)
+        loss = loss.astype(input_type)
+    return loss
+
+__all__.append('ctc_loss')
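
The new Ascend ctc_loss prim runs CTCLossV2 with reduction="none" and applies the reduction itself. A minimal NumPy sketch of that reduction step (hypothetical loss values, not the mindnlp API) shows the 'mean' branch: each per-sample loss is divided by its target length, clipped to at least 1, before averaging.

import numpy as np

# Hypothetical per-sample CTC losses and target lengths.
loss = np.array([12.0, 30.0, np.inf], dtype=np.float32)
target_lengths = np.array([4, 10, 0], dtype=np.float32)

zero_infinity = True
if zero_infinity:
    loss = np.where(np.isinf(loss), 0.0, loss)  # zero out infinite losses

# 'mean' reduction: normalize by clipped target length, then average.
mean_loss = (loss / np.clip(target_lengths, 1.0, None)).mean()
print(mean_loss)  # (12/4 + 30/10 + 0/1) / 3 = 2.0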

mindnlp/core/_prims/meta.py

Lines changed: 51 additions & 0 deletions
@@ -367,3 +367,54 @@ def sqrt(input):
     return input
 
 __all__.append('sqrt')
+
+def normal_float_float(mean, std, size, seed, offset):
+    out = Tensor_(shape=size, dtype=core.float32)
+    return core.Tensor(out)
+
+
+__all__.append('normal_float_float')
+
+def stack(tensors, dim):
+    x_shape = list(tensors[0].shape)
+    x_shape.insert(dim, len(tensors))
+    out = Tensor_(shape=tuple(x_shape), dtype=tensors[0].dtype)
+    return core.Tensor(out)
+
+__all__.append('stack')
+
+def argmax_with_value(input, dim, keepdim):
+    out_shape = list(input.shape)
+    if keepdim:
+        out_shape[dim] = 1
+    else:
+        out_shape.pop(dim)
+
+    indices = Tensor_(shape=out_shape, dtype=core.int64)
+    values = Tensor_(shape=out_shape, dtype=input.dtype)
+
+    return core.Tensor(indices), core.Tensor(values)
+
+__all__.append('argmax_with_value')
+
+def tile(input, dims):
+    input_shape = input.shape
+    out_shape = [input_shape[i] * dims[i] for i in range(input.ndim)]
+    out = Tensor_(shape=tuple(out_shape), dtype=input.dtype)
+    return core.Tensor(out)
+
+__all__.append('tile')
+
+def flatten_ext(input, start_dim, end_dim):
+    input_shape = list(input.shape)
+    if start_dim < 0:
+        start_dim = start_dim + input.ndim
+    if end_dim < 0:
+        end_dim = end_dim + input.ndim
+
+    flatten_shape = input_shape[:start_dim] + input_shape[start_dim:end_dim+1] + input_shape[end_dim+1:]
+    out = Tensor_(shape=tuple(flatten_shape), dtype=input.dtype)
+    return core.Tensor(out)
+
+__all__.append('flatten_ext')
+
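
These meta prims only infer output shapes; no values are computed. The shape arithmetic can be checked in plain Python, for example for tile and argmax_with_value (illustrative shapes, not the mindnlp API):

# Shape inference mirroring the meta prims above (plain Python, no tensors).
input_shape = [2, 3]

# tile: each dimension is multiplied by the corresponding repeat factor.
dims = (4, 5)
tile_shape = [input_shape[i] * dims[i] for i in range(len(input_shape))]
print(tile_shape)  # [8, 15]

# argmax_with_value: the reduced dim is kept as size 1 or dropped entirely.
dim, keepdim = 1, False
out_shape = list(input_shape)
if keepdim:
    out_shape[dim] = 1
else:
    out_shape.pop(dim)
print(out_shape)  # [2]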

mindnlp/core/_prims/numpy.py

Lines changed: 13 additions & 0 deletions
@@ -766,3 +766,16 @@ def repeat_interleave_tensor(input, repeats, dim, _):
 
 __all__.append('repeat_interleave_tensor')
 
+def greater(input, other):
+    if not isinstance(input, numbers.Number):
+        input = input.numpy()
+    if not isinstance(other, numbers.Number):
+        other = other.numpy()
+
+    out = input > other
+    if not isinstance(out, np.ndarray):
+        out = np.array(out)
+
+    return core.Tensor.from_numpy(out)
+
+__all__.append('greater')
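
The NumPy backend's greater accepts Python scalars or tensors on either side; the comparison itself is plain NumPy broadcasting. A standalone sketch of the same logic (pure NumPy, not the mindnlp wrapper):

import numbers
import numpy as np

def greater_np(input, other):
    # Convert non-scalar operands to ndarrays, as the prim above does via .numpy().
    if not isinstance(input, numbers.Number):
        input = np.asarray(input)
    if not isinstance(other, numbers.Number):
        other = np.asarray(other)
    out = input > other
    # A scalar-vs-scalar comparison yields a Python bool; wrap it as an array.
    if not isinstance(out, np.ndarray):
        out = np.array(out)
    return out

print(greater_np([1, 5, 3], 2))  # [False  True  True]
print(greater_np(3, 2))          # True (0-d bool array)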

mindnlp/core/_tensor.py

Lines changed: 9 additions & 5 deletions
@@ -310,6 +310,8 @@ def __rmul__(self, other):
         return self.item() * other
     return self.__mul__(other)
 
+def __abs__(self):
+    return ops.abs(self)
 
 def __imul__(self, other):
     return self.copy_(ops.mul(self, other))
@@ -2038,8 +2040,8 @@ def size(self, dim=None):
 
 
 # Tensor.softmax
-def softmax(self, dim):
-    return ops.softmax(self, dim)
+def softmax(self, dim, dtype=None):
+    return ops.softmax(self, dim, dtype=dtype)
 
 # Tensor.sort
 def sort(self, dim=-1, descending=False):
@@ -2125,7 +2127,8 @@ def sub_(self, other, *, alpha=1):
 subtract_ = sub_
 
 # Tensor.sum
-def sum(self, dim=None, keepdim=False, dtype=None):
+def sum(self, dim=None, keepdim=False, dtype=None, **kwargs):
+    dim = kwargs.pop('axis', dim)
     return ops.sum(self, dim, keepdim, dtype=dtype)
 
 # Tensor.sum_to_size
@@ -2155,7 +2158,8 @@ def t_(self):
     return self
 
 # Tensor.tensor_split
-
+def tensor_split(self, indices_or_sections, dim=0):
+    return ops.tensor_split(self, indices_or_sections, dim)
 
 # Tensor.tile
 def tile(self, *dims):
@@ -2438,7 +2442,7 @@ def detach(self):
 
 # Tensor.detach_
 def detach_(self):
-    self.requires_grad_(self)
+    self.requires_grad_(False)
    return self
 
 def stub_sync(self):
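
Tensor.sum now also accepts the NumPy/MindSpore-style axis= keyword as an alias for dim=. The dispatch is just kwargs.pop with the positional value as the default; a standalone sketch (hypothetical function, not the Tensor class itself):

def sum_like(data, dim=None, keepdim=False, dtype=None, **kwargs):
    # axis= is accepted as an alias for dim=, mirroring the Tensor.sum change above.
    dim = kwargs.pop('axis', dim)
    return f"reducing over dim={dim}, keepdim={keepdim}"

print(sum_like([[1, 2], [3, 4]], axis=1))  # reducing over dim=1, keepdim=False
print(sum_like([[1, 2], [3, 4]], dim=0))   # reducing over dim=0, keepdim=False

The same pattern backs the size= keyword added to ops.zeros further down.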

mindnlp/core/nn/functional.py

Lines changed: 4 additions & 20 deletions
@@ -197,8 +197,8 @@ def rms_norm(input, normalized_shape, weight, eps=None):
     if eps is None:
         eps = core.finfo(input.dtype).eps
     if weight is None:
-        weight = core.ones(normalized_shape)
-    return ops.rms_norm(input, weight, eps)[0]
+        weight = core.ones(normalized_shape, dtype=input.dtype, device=input.device)
+    return execute('rms_norm', input, weight, eps)
 
 def fast_gelu(x):
     return ops.fast_gelu(x)
@@ -760,7 +760,6 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
 def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     if isinstance(padding, str):
         return execute('conv3d_padding', input, weight, bias, stride, padding, dilation, groups)
-    print(input.device, weight.device)
     return execute('conv3d_ext', input, weight, bias, stride, padding, dilation, groups)
 
     pad_mode = 'pad'
@@ -1577,28 +1576,13 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
     return ops.fold(input, output_size, kernel_size, dilation, padding, stride)
 
 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
-    ctc_loss_op = _get_cache_prim(ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
-    if targets.ndim == 1:
-        targets = targets.unsqueeze(-1)
-    loss, _ = ctc_loss_op(log_probs, targets, input_lengths, target_lengths)
-    if zero_infinity:
-        loss = ops.where(ops.isinf(loss), 0., loss)
-    if reduction == 'sum':
-        loss = loss.sum()
-    if reduction == 'mean':
-        input_type = loss.dtype
-        target_length_t = target_lengths.clip(1., None)
-        loss = loss.astype("float32")
-        loss = loss / target_length_t
-        loss = loss.mean()
-        loss = loss.astype(input_type)
-    return loss
+    return execute('ctc_loss', log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
 
 def one_hot(tensor, num_classes=-1):
     return execute('one_hot_ext', tensor, num_classes)
 
 def pixel_shuffle(input, upscale_factor):
-    return ops.pixel_shuffle(input, upscale_factor)
+    return execute('pixel_shuffle', input, upscale_factor)
 
 def pixel_unshuffle(input, downscale_factor):
     return ops.pixel_unshuffle(input, downscale_factor)
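
rms_norm now materializes the default weight with the input's dtype and device and dispatches through execute. For reference, RMSNorm normalizes by the root mean square over the normalized dimensions; a NumPy sketch of the formula (reference math, not mindnlp's kernel):

import numpy as np

def rms_norm_ref(x, weight, eps=1e-6):
    # y = x / sqrt(mean(x^2) + eps) * weight, reduced over the last dimension.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * weight

x = np.random.randn(2, 4).astype(np.float32)
w = np.ones(4, dtype=np.float32)  # default weight, matching core.ones(normalized_shape) above
print(rms_norm_ref(x, w).shape)   # (2, 4)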

mindnlp/core/nn/parameter.py

Lines changed: 0 additions & 3 deletions
@@ -27,9 +27,6 @@ def __deepcopy__(self, memodict):
         new_obj._device = self.device
         return new_obj
 
-    def clone(self):
-        return copy.deepcopy(self)
-
     def __parameter__(self): # only for O2
         """For parse check."""

mindnlp/core/ops/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 """core ops like torch funcional api"""
 from . import array, blas, comparison, pointwise, creation, random, reduction, other, \
-    tensor, _inner, optim, inplace
+    tensor, _inner, optim, inplace, complex
 from .array import *
 from .blas import *
 from .comparison import *
@@ -14,6 +14,7 @@
 from ._inner import *
 from .optim import *
 from .inplace import *
+from .complex import *
 
 def load_library(lib_path):
     raise ImportError('not support import any ops for now.')

mindnlp/core/ops/array.py

Lines changed: 45 additions & 4 deletions
@@ -166,7 +166,12 @@ def permute(input, dims):
 def reshape(input, *shape):
     if isinstance(shape[0], (tuple, list)):
         shape = shape[0]
-    return execute("reshape", input, shape)
+    new_shape = ()
+    for s in shape:
+        if not isinstance(s, numbers.Number):
+            s = s.item()
+        new_shape += (s,)
+    return execute("reshape", input, new_shape)
 
 
 def view(input, *shape):
@@ -221,6 +226,22 @@ def split(tensor, split_size_or_sections, dim=0):
     )
     return res
 
+def split_with_sizes(input, split_sizes, dim=0):
+    assert input.dim() != 0, "split expects at least a 1-dimensional tensor"
+    dim_size = input.size(dim)
+    num_splits = len(split_sizes)
+    start_idx = 0
+
+    splits = []
+    for i in range(num_splits):
+        length = split_sizes[i]
+        assert length >= 0, f"split_with_sizes expects split_sizes have only non-negative entries, but got split_sizes={split_sizes}"
+        splits.append(
+            narrow(input, dim, start_idx, length)
+        )
+        start_idx += length
+
+    return splits
 
 # squeeze
 def squeeze(input, *dim, **kwargs):
@@ -306,10 +327,27 @@ def take_along_dim(input, indices, dim=None, *, out=None):
     return input.view(-1).gather(0, indices.view(-1))
 
 # tensor_split
-
+def tensor_split(input, indices_or_sections, dim=0):
+    if isinstance(indices_or_sections, int):
+        # split into roughly equal chunks
+        dim_size = input.size(dim)
+        if dim_size == 0:
+            return [input] * indices_or_sections
+        split_size = (dim_size + indices_or_sections - 1) // indices_or_sections
+        return split(input, split_size, dim=dim)
+    elif isinstance(indices_or_sections, (list, tuple, core.Tensor)):
+        # split at the given indices
+        dim_size = input.size(dim)
+        indices = [0] + list(indices_or_sections) + [dim_size]
+        split_sizes = [indices[i+1] - indices[i] for i in range(len(indices)-1)]
+        return split(input, split_sizes, dim=dim)
+    else:
+        raise ValueError("indices_or_sections must be int or list/tuple of indices")
 
 # tile
 def tile(input, dims):
+    if isinstance(dims[0], (tuple, list)):
+        dims = dims[0]
     return execute("tile", input, dims)
 
 
@@ -882,6 +920,8 @@ def getitem_np(input, slice):
     return execute('getitem', input, slice)
 
 def setitem_np(input, slice, value):
+    if input.device != value.device:
+        value = value.to(input.device)
     return execute('setitem', input, slice, value)
 
 __all__ = [
@@ -926,7 +966,7 @@ def setitem_np(input, slice, value):
     "swapdims",
     "take",
     "take_along_dim",
-    # tensor_split
+    "tensor_split",
     "tile",
     "transpose",
     "unbind",
@@ -940,5 +980,6 @@ def setitem_np(input, slice, value):
     'getitem',
     'setitem',
     'getitem_np',
-    'setitem_np'
+    'setitem_np',
+    'split_with_sizes'
 ]
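
The new tensor_split either splits into roughly equal chunks (integer argument) or converts a list of split indices into per-chunk sizes for split. The index-to-sizes conversion can be checked in plain Python (illustrative values):

# Splitting a dimension of size 10 at indices [2, 5], as tensor_split above does:
dim_size = 10
indices_or_sections = [2, 5]

indices = [0] + list(indices_or_sections) + [dim_size]
split_sizes = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)]
print(split_sizes)  # [2, 3, 5] -> chunks input[0:2], input[2:5], input[5:10]

# Integer argument: ceil-divide into roughly equal chunks.
sections = 3
split_size = (dim_size + sections - 1) // sections
print(split_size)   # 4 -> chunks of size 4, 4, 2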

mindnlp/core/ops/creation.py

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ def frombuffer(buffer, *, dtype, count=-1, offset=0, requires_grad=False):
 
 
 # zeros
-def zeros(*size, out=None, dtype=None, layout=None, device=None, requires_grad=False):
+def zeros(*size, out=None, dtype=None, layout=None, device=None, requires_grad=False, **kwargs):
+    size = kwargs.pop('size', size)
     if dtype is None:
         dtype = get_default_dtype()
     if device is None:

mindnlp/core/ops/other.py

Lines changed: 12 additions & 2 deletions
@@ -777,9 +777,15 @@ def unflatten(x, dim, sizes):
 
 
 # view_as_real
+def view_as_real(input):
+    real_part = input.real.unsqueeze(-1)
+    imag_part = input.imag.unsqueeze(-1)
+    return core.concat((real_part, imag_part), -1)
 
-# view_as_complex
 
+# view_as_complex
+def view_as_complex(input):
+    return execute('view_as_complex', input)
 
 # resolve_conj
 
@@ -794,6 +800,8 @@ def masked_fill(input, mask, value):
     if value == float('inf'):
         value = finfo(input.dtype).max
 
+    if isinstance(value, core.Tensor) and input.device != value.device:
+        value = value.to(input.device)
     return execute('masked_fill', input, mask, value)
 
 class finfo:
@@ -949,5 +957,7 @@ def dyn_shape(input):
     "contiguous",
     "ravel",
     "dyn_shape",
-    "diff"
+    "diff",
+    'view_as_complex',
+    'view_as_real'
 ]