
Commit d316cf3

fix apis for e class (#2152)
1 parent 6719e5e commit d316cf3

File tree (10 files changed: +96 −33 lines changed)

mindnlp/core/_bind.py
mindnlp/core/_prims/ascend.py
mindnlp/core/_prims/meta.py
mindnlp/core/_prims/numpy.py
mindnlp/core/_tensor.py
mindnlp/core/linalg/__init__.py
mindnlp/core/nn/functional.py
mindnlp/core/nn/modules/rnn.py
mindnlp/core/ops/array.py
mindnlp/core/ops/other.py

mindnlp/core/_bind.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def get_autocast_dtype(device_type):
 def get_autocast_gpu_dtype():
     return AUTO_CAST_DTYE['cuda']
 
-def is_autocast_enabled(device):
+def is_autocast_enabled(device=None):
     return AUTO_CAST_ENABLED
 
 def set_default_dtype(dtype):
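
Usage note (not part of the commit): with the default argument added above, is_autocast_enabled can now be called with or without a device, mirroring torch.is_autocast_enabled(). The snippet below assumes the function is re-exported at the mindnlp.core top level, as torch does.

from mindnlp import core

core.is_autocast_enabled()        # no device argument needed any more
core.is_autocast_enabled('cuda')  # an explicit device is still accepted (and currently unused)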

mindnlp/core/_prims/ascend.py

Lines changed: 21 additions & 0 deletions
@@ -375,3 +375,24 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reducti
     return loss
 
 __all__.append('ctc_loss')
+
+def reduce_max(input, dim, keepdim):
+    return pyboost_inner_prim.reduce_max_impl(input, dim, keepdim)
+
+__all__.append('reduce_max')
+
+def elu(input, alpha):
+    return pyboost_inner_prim.elu_ext_impl(input, alpha)
+
+__all__.append('elu')
+
+dynamic_rnn_op = ops.DynamicRNN().set_device('Ascend')
+def dynamic_rnn(*args):
+    return dynamic_rnn_op(*args)
+
+__all__.append('dynamic_rnn')
+
+def cross(input, other, dim=None, *, out=None):
+    return pyboost_inner_prim.cross_impl(input, other, dim)
+
+__all__.append('cross')
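
A hedged sketch of how these new Ascend prims are reached: higher-level ops call execute('<prim_name>', ...) and the executor dispatches to the backend module that matches the tensor's device, which is _prims/ascend.py on an NPU. core.randn is an assumed convenience factory here, not something this diff introduces.

from mindnlp import core
from mindnlp.core.executor import execute
from mindnlp.core.nn import functional as F

x = core.randn(2, 3)                    # assumed factory; lives on the current device
y = F.elu(x, alpha=1.0)                 # -> execute('elu', x, 1.0) -> ascend.elu on NPU
m = execute('reduce_max', x, 1, True)   # direct prim call; signature as added above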

mindnlp/core/_prims/meta.py

Lines changed: 19 additions & 0 deletions
@@ -418,3 +418,22 @@ def flatten_ext(input, start_dim, end_dim):
 
 __all__.append('flatten_ext')
 
+def cumsum_ext(input, dim, dtype):
+    return input
+
+__all__.append('cumsum_ext')
+
+def squeeze(input, dim):
+    input_shape = list(input.shape)
+    if isinstance(dim, int):
+        dim = (dim,)
+
+    new_shape = ()
+    for idx, s in enumerate(input_shape):
+        if idx not in dim and s != 1:
+            new_shape += (s,)
+
+    out = Tensor_(shape=tuple(new_shape), dtype=input.dtype)
+    return core.Tensor(out)
+
+__all__.append('squeeze')
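
The meta backend only has to predict output shapes, so the new squeeze builds a Tensor_ with the reduced shape and wraps it. A stand-alone reproduction of that shape rule (drop the requested dims plus any remaining size-1 dims) is sketched below; squeeze_out_shape is an illustrative helper, not part of mindnlp.

def squeeze_out_shape(shape, dim):
    # mirrors the loop in the meta squeeze above
    if isinstance(dim, int):
        dim = (dim,)
    return tuple(s for idx, s in enumerate(shape) if idx not in dim and s != 1)

print(squeeze_out_shape((1, 3, 1, 5), dim=2))  # (3, 5)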

mindnlp/core/_prims/numpy.py

Lines changed: 6 additions & 0 deletions
@@ -779,3 +779,9 @@ def greater(input, other):
     return core.Tensor.from_numpy(out)
 
 __all__.append('greater')
+
+def linalg_vector_norm(input, p, dim, keepdim, dtype):
+    out = np.linalg.norm(input.numpy(), p, dim, keepdim)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('linalg_vector_norm')
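
The numpy backend simply forwards to np.linalg.norm, whose positional signature (x, ord, axis, keepdims) lines up with the prim's (input, p, dim, keepdim); note the prim's dtype argument is not applied in this version. A minimal NumPy-only check of that mapping:

import numpy as np

x = np.array([[3.0, 4.0], [6.0, 8.0]])
out = np.linalg.norm(x, 2, 1, True)   # p=2, dim=1, keepdim=True
print(out)                            # [[ 5.] [10.]]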

mindnlp/core/_tensor.py

Lines changed: 18 additions & 4 deletions
@@ -386,11 +386,20 @@ def new(self, *shape):
 
 # Tensor.new_tensor
 def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
-    return tensor(data, dtype=dtype if dtype is not None else self.dtype)
+    if device is None:
+        device = self.device
+    if dtype is None:
+        dtype = self.dtype
+
+    return tensor(data, dtype=dtype, device=device)
 
 # Tensor.new_full
 def new_full(self, size, fill_value, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
-    return ops.full(size, fill_value, dtype=dtype if dtype is not None else self.dtype)
+    if device is None:
+        device = self.device
+    if dtype is None:
+        dtype = self.dtype
+    return ops.full(size, fill_value, dtype=dtype, device=device)
 
 # Tensor.new_empty
 def new_empty(self, size, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
@@ -1058,7 +1067,8 @@ def expm1_(self):
 
 
 # Tensor.expand
-def expand(self, *size):
+def expand(self, *size, **kwargs):
+    size = kwargs.pop('size', size)
     if len(size) == 1:
         size = size[0]
     return self.broadcast_to(size)
@@ -1284,7 +1294,7 @@ def index_select(self, dim, index):
 
 # Tensor.int
 def int(self):
-    return self.to(mindspore.int64)
+    return self.to(mindspore.int32)
 
 # Tensor.int_repr
 
@@ -2129,6 +2139,7 @@ def sub_(self, other, *, alpha=1):
 # Tensor.sum
 def sum(self, dim=None, keepdim=False, dtype=None, **kwargs):
     dim = kwargs.pop('axis', dim)
+    keepdim = kwargs.pop('keepdims', keepdim)
     return ops.sum(self, dim, keepdim, dtype=dtype)
 
 # Tensor.sum_to_size
@@ -2551,6 +2562,9 @@ def log_softmax(self, dim):
 def char(self):
     return self.to(core.int8)
 
+def cross(self, other, dim=None):
+    return ops.cross(self, other, dim)
+
 @property
 def is_nested(self):
     return False
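
A hedged usage sketch of the behavioral changes above, assuming core.tensor exists as the usual torch-style factory: new_tensor and new_full now inherit both dtype and device from the source tensor, Tensor.int() casts to int32 rather than int64, sum() additionally accepts the NumPy-style keepdims= alias, and cross is available as a Tensor method.

from mindnlp import core

base = core.tensor([[1.0, 2.0], [3.0, 4.0]])   # assumed factory
same = base.new_tensor([5.0, 6.0])             # dtype and device copied from base
full = base.new_full((2, 2), 7.0)              # ditto

print(base.int().dtype)                        # int32 after this commit
print(base.sum(dim=0, keepdims=True).shape)    # keepdims alias now honoured

a = core.tensor([1.0, 0.0, 0.0])
b = core.tensor([0.0, 1.0, 0.0])
print(a.cross(b))                              # [0., 0., 1.]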

mindnlp/core/linalg/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def cholesky_ex(A, *, upper=False, check_errors=False, out=None):
 
 
 def norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None):
-    return mint.norm(A, 2 if ord is None else ord, dim, keepdim, dtype=dtype)
+    return core.norm(A, 2 if ord is None else ord, dim, keepdim, dtype=dtype)
 
 def vector_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None):
     return execute('linalg_vector_norm', x, ord, dim, keepdim, dtype=dtype)
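
A minimal usage sketch, not from the commit: linalg.norm now routes through core.norm instead of mint.norm, so it no longer depends on the mint backend, and vector_norm dispatches the linalg_vector_norm prim, which on the numpy/CPU path is presumably backed by the function added in this same commit. core.tensor is an assumed factory.

from mindnlp import core

v = core.tensor([3.0, 4.0])               # assumed factory
print(core.linalg.norm(v))                # 5.0 (ord=None falls back to the 2-norm)
print(core.linalg.vector_norm(v, ord=2))  # same value via the linalg_vector_norm prim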

mindnlp/core/nn/functional.py

Lines changed: 7 additions & 4 deletions
@@ -47,7 +47,7 @@ def relu6(input):
     return execute('relu6', input)
 
 def elu(input, alpha=1.0):
-    return execute('relu6', input, alpha)
+    return execute('elu', input, alpha)
 
 def glu(input, dim=-1):
     return execute('glu', input, dim)
@@ -59,9 +59,7 @@ def logsigmoid(input):
     return execute('logsigmoid', input)
 
 def leaky_relu(input, alpha=0.2):
-    if use_pyboost():
-        return mint.nn.functional.leaky_relu(input, alpha)
-    return ops.leaky_relu(input, alpha)
+    return execute('leaky_relu_ext', input, alpha)
 
 
 def prelu(input, weight):
@@ -284,6 +282,9 @@ def _circular_pad(input_x, padding):
     return out
 
 def pad(input, pad, mode='constant', value=None):
+    if isinstance(pad, tuple):
+        pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
+
     if input.device.type in ['cpu', 'meta'] or ON_A1:
         new_pad = ()
         for idx, pad_v in enumerate(pad):
@@ -296,6 +297,8 @@ def pad(input, pad, mode='constant', value=None):
             return input
         if mode == 'circular':
            return custom_circular_pad(input, pad)
+        elif mode == 'reflect':
+            return execute('pad_v3', input, new_pad, mode)
         return execute('pad_v3', input, new_pad, mode, value)
     out = input
     if (isinstance(pad, tuple) and not pad):
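
The new first step in pad() normalises the pad tuple so that 0-d tensor entries (e.g. sizes computed on device) are unwrapped to plain Python ints before dispatch, and the reflect branch then goes through pad_v3 without a value argument. A stand-alone sketch of the normalisation, using NumPy scalars to stand in for 0-d tensors:

import numpy as np

def normalize_pad(pad):
    # same idea as the isinstance/.item() line added to pad() above
    if isinstance(pad, tuple):
        pad = tuple(p if isinstance(p, int) else p.item() for p in pad)
    return pad

print(normalize_pad((1, np.int64(2), 0, np.int64(3))))  # (1, 2, 0, 3)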

mindnlp/core/nn/modules/rnn.py

Lines changed: 10 additions & 8 deletions
@@ -19,6 +19,7 @@
 from mindspore import ops as P
 
 from mindnlp import core
+from mindnlp.core.executor import execute
 from .module import Module
 from .dropout import Dropout
 from ..parameter import Parameter
@@ -29,9 +30,9 @@
 __all__ = ['LSTM', 'GRU', 'RNN']
 
 
-def _init_state(shape, dtype, is_lstm):
-    hx = ops.zeros(*shape, dtype=dtype)
-    cx = ops.zeros(*shape, dtype=dtype)
+def _init_state(shape, dtype, device, is_lstm):
+    hx = ops.zeros(*shape, dtype=dtype, device=device)
+    cx = ops.zeros(*shape, dtype=dtype, device=device)
     if is_lstm:
         return (hx, cx)
     return hx
@@ -285,7 +286,7 @@ def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
         w_hh = ops.cat((w_hh_i, w_hh_g, w_hh_f, w_hh_o), 0)
         weight = ops.cat((w_ih, w_hh), 1)
         if b_ih is None:
-            bias = ops.zeros(w_ih.shape[0], dtype=w_ih.dtype)
+            bias = ops.zeros(w_ih.shape[0], dtype=w_ih.dtype, device=w_ih.device)
         else:
             b_ih_i, b_ih_f, b_ih_g, b_ih_o = ops.chunk(b_ih, 4, 0)
             b_hh_i, b_hh_f, b_hh_g, b_hh_o = ops.chunk(b_hh, 4, 0)
@@ -294,7 +295,8 @@ def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
                                 b_ih_f + b_hh_f, \
                                 b_ih_o + b_hh_o), 0)
 
-        outputs, h, c, _, _, _, _, _ = self.lstm(x.to(core.float16), \
+        outputs, h, c, _, _, _, _, _ = execute('dynamic_rnn',
+            x.to(core.float16), \
             ops.transpose(weight, 1, 0).to(core.float16), \
             bias.to(core.float16), None, \
            h_0[0].unsqueeze(0).to(core.float16), \
@@ -314,8 +316,8 @@ class _RNNBase(Module):
     '''Basic class for RNN operators'''
 
     def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True,
-                 batch_first=False, dropout=0., bidirectional=False, dtype=None):
-        factory_kwargs = {'dtype': dtype}
+                 batch_first=False, dropout=0., bidirectional=False, dtype=None, device=None):
+        factory_kwargs = {'dtype': dtype, 'device': device}
         super().__init__()
 
         if not 0 <= dropout < 1:
@@ -495,7 +497,7 @@ def forward(self, x, hx=None, seq_length=None):
         x_dtype = x.dtype
         if hx is None:
             hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), \
-                             x_dtype, self.is_lstm)
+                             x_dtype, x.device, self.is_lstm)
         if self.batch_first:
             x = ops.permute(x, (1, 0, 2))
         if self.bidirectional:
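
A hedged usage sketch, not part of the commit: _init_state now receives the input's device, so the default hidden (and cell) state is allocated where the input lives instead of on the framework default, and _RNNBase threads a device= kwarg through factory_kwargs. core.randn is an assumed factory, and the torch-style LSTM return structure is assumed here.

from mindnlp import core
from mindnlp.core import nn

rnn = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
x = core.randn(4, 5, 8)     # (batch, seq, feature); assumed factory
out, (h, c) = rnn(x)        # default hx is now created on x.device
print(h.shape)              # (1, 4, 16)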

mindnlp/core/ops/array.py

Lines changed: 2 additions & 15 deletions
@@ -218,6 +218,7 @@ def split(tensor, split_size_or_sections, dim=0):
     if isinstance(split_size_or_sections, int):
         res = execute("split_tensor", tensor, split_size_or_sections, dim)
     elif isinstance(split_size_or_sections, (list, tuple)):
+        split_size_or_sections = tuple(s.item() if isinstance(s, core.Tensor) else s for s in split_size_or_sections)
         res = execute("split_with_size", tensor, split_size_or_sections, dim)
     else:
         raise TypeError(
@@ -227,21 +228,7 @@ def split(tensor, split_size_or_sections, dim=0):
     return res
 
 def split_with_sizes(input, split_sizes, dim=0):
-    assert input.dim() != 0, "split expects at least a 1-dimensional tensor"
-    dim_size = input.size(dim)
-    num_splits = len(split_sizes)
-    start_idx = 0
-
-    splits = []
-    for i in range(num_splits):
-        length = split_sizes[i]
-        assert length >= 0, f"split_with_sizes expects split_sizes have only non-negative entries, but got split_sizes={split_sizes}"
-        splits.append(
-            narrow(input, dim, start_idx, length)
-        )
-        start_idx += length
-
-    return splits
+    return execute("split_with_size", input, split_sizes, dim)
 
 # squeeze
 def squeeze(input, *dim, **kwargs):
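
split_with_sizes now delegates to the single split_with_size prim instead of looping over narrow(), and split() unwraps 0-d tensor entries in the size list with .item() before dispatch. What the op computes, sketched with NumPy for clarity:

import numpy as np

x = np.arange(10)
sizes = (3, 3, 4)
chunks = np.split(x, np.cumsum(sizes)[:-1])
print([c.tolist() for c in chunks])   # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]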

mindnlp/core/ops/other.py

Lines changed: 11 additions & 0 deletions
@@ -88,6 +88,10 @@ def clone(input, *, memory_format=core.preserve_format):
 
 # cumsum
 def cumsum(input, dim, dtype=None):
+    if input.dtype in [core.int64, core.bool]:
+        return execute('cumsum_ext', input.int(), dim, None).long()
+    if dtype is not None and dtype == core.int64:
+        return execute('cumsum_ext', input, dim, None).long()
     return execute('cumsum_ext', input, dim, dtype)
 
 # diag
@@ -928,6 +932,12 @@ def contiguous(input):
 def dyn_shape(input):
     return execute('dyn_shape', input)
 
+def cross(input, other, dim=None, *, out=None):
+    if dim is None:
+        dim = -65530
+    return execute('cross', input, other, dim)
+
+
 __all__ = [
     "bincount",
     "broadcast_shapes",
@@ -936,6 +946,7 @@ def dyn_shape(input):
     "cdist",
     "clone",
     "contains",
+    "cross",
     "cumsum",
     "diag",
     "diagonal",
