
Commit ff69c06

fix apis for s-z class (#2159)
1 parent 83e74c2 commit ff69c06

File tree

8 files changed: +36 -4145 lines changed


mindnlp/core/_prims/numpy.py

Lines changed: 17 additions & 0 deletions
@@ -388,6 +388,13 @@ def inplace_fill_scalar(input, value):
 
 __all__.append('inplace_fill_scalar')
 
+def inplace_fill_tensor(input, value):
+    out = np.full_like(input.numpy(), value)
+    numpy_to_tensor_overwrite(out, input)
+    return input
+
+__all__.append('inplace_fill_tensor')
+
 def inplace_normal(input, mean, std, generator_):
     out = np.random.normal(mean, std, input.shape).astype(core.dtype2np[input.dtype])
     numpy_to_tensor_overwrite(out, input)
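As a rough illustration of what the new prim does, here is a plain-NumPy sketch (no mindnlp imports; the buffer overwrite done by numpy_to_tensor_overwrite is modelled with a simple slice assignment):

import numpy as np

# NumPy analogue of inplace_fill_tensor: build a same-shape, same-dtype array
# filled with `value`, then copy it back over the original buffer in place.
buf = np.arange(6, dtype=np.float32).reshape(2, 3)
fill = np.float32(7.0)             # stand-in for the 0-d `value` tensor
buf[...] = np.full_like(buf, fill)
print(buf)                         # [[7. 7. 7.] [7. 7. 7.]]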
@@ -600,6 +607,8 @@ def randn(size, seed, offset, dtype):
 
 def erfinv(input):
     out = scipy.special.erfinv(input)
+    if not isinstance(out, np.ndarray):
+        out = np.array(out)
     return core.Tensor.from_numpy(out)
 
 __all__.append('erfinv')
@@ -910,6 +919,8 @@ def maximum(input, other):
 
 def prod_ext(input, dim, keepdim, dtype):
     out = np.prod(input.numpy(), axis=dim, keepdims=keepdim)
+    if not isinstance(out, np.ndarray):
+        out = np.array(out)
     return core.Tensor.from_numpy(out)
 
 __all__.append('prod_ext')
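The isinstance guards added to erfinv and prod_ext exist because a fully reduced np.prod (and scipy.special.erfinv on a scalar input) hands back a NumPy scalar rather than an ndarray, which core.Tensor.from_numpy presumably cannot wrap directly. A standalone check in plain NumPy/SciPy:

import numpy as np
from scipy import special

out = np.prod(np.arange(1, 5), axis=None, keepdims=False)
print(type(out), isinstance(out, np.ndarray))   # <class 'numpy.int64'> False

e = special.erfinv(0.5)
print(type(e), isinstance(e, np.ndarray))       # <class 'numpy.float64'> False

# The guard from this commit promotes such scalars to 0-d ndarrays:
if not isinstance(out, np.ndarray):
    out = np.array(out)
print(type(out), out.shape)                     # <class 'numpy.ndarray'> ()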
@@ -1136,3 +1147,9 @@ def sign(input):
     return core.Tensor.from_numpy(out)
 
 __all__.append('sign')
+
+def log2(input):
+    out = np.log2(input.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('log2')

mindnlp/core/_tensor.py

Lines changed: 9 additions & 5 deletions
@@ -348,6 +348,8 @@ def __rsub__(self, other):
     return ops.sub(other, self)
 
 def __eq__(self, other):
+    if other is None:
+        return False
     return ops.eq(self, other)
 
 def __gt__(self, other):
@@ -1891,7 +1893,8 @@ def repeat_interleave(self, repeats, dim=None, output_size=None):
     return ops.repeat_interleave(self, repeats, dim, output_size=output_size)
 
 # Tensor.reshape
-def reshape(self, *shape):
+def reshape(self, *shape, **kwargs):
+    shape = kwargs.pop('shape', shape)
     return ops.reshape(self, *shape)
 
 # Tensor.reshape_as
@@ -1956,13 +1959,13 @@ def scatter_add(self, dim, index, src):
 
 
 # Tensor.scatter_reduce_
-def scatter_reduce_(self, dim, index, src):
+def scatter_reduce_(self, dim, index, src, reduce, *, include_self=True):
     return self.copy_(ops.scatter_reduce(self, dim, index, src))
 
 
 # Tensor.scatter_reduce
-def scatter_reduce(self, dim, index, src):
-    return ops.scatter_reduce(self, dim, index, src)
+def scatter_reduce(self, dim, index, src, reduce, *, include_self=True):
+    return ops.scatter_reduce(self, dim, index, src, reduce)
 
 
 # Tensor.select
@@ -2436,7 +2439,8 @@ def unsqueeze_(self, dim):
 
 
 # Tensor.var
-def var(self, dim=None, *, correction=1, keepdim=False):
+def var(self, dim=None, *, correction=1, keepdim=False, **kwargs):
+    correction = int(kwargs.pop('unbiased', correction))
     return ops.var(self, dim, correction=correction, keepdim=keepdim)
 
 # Tensor.vdot
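Taken together, these _tensor.py tweaks accept torch-style call patterns. A hedged caller-side sketch of what they are presumably meant to allow (the import path comes from this diff, but core.tensor is an assumed constructor name; adjust to your actual entry point):

from mindnlp import core            # path from this diff; usage assumed

t = core.tensor([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])  # `core.tensor` is an assumed constructor name

print(t == None)                    # __eq__ now short-circuits to False for None

u = t.reshape(shape=(3, 2))         # keyword form routed through the new **kwargs
print(u.shape)                      # expected: (3, 2)

v = t.var(dim=0, unbiased=False)    # legacy unbiased= mapped onto correction=0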

mindnlp/core/nn/functional.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def softplus(input, beta=1, threshold=20):
     return execute('softplus_ext', input, beta, threshold)
 
 def logsigmoid(input):
-    return execute('logsigmoid', input)
+    return execute('logsigmoid', input)[0]
 
 def leaky_relu(input, alpha=0.2):
     return execute('leaky_relu_ext', input, alpha)
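The added [0] suggests the 'logsigmoid' kernel returns more than one output (presumably the activation plus an auxiliary result such as a buffer kept for the backward pass), so only the first element is returned. For reference, log-sigmoid itself is just -softplus(-x); a small numerically stable NumPy sketch, purely illustrative and not mindnlp code:

import numpy as np

def logsigmoid_reference(x):
    # log(sigmoid(x)) = -softplus(-x); the min/log1p form stays stable for large |x|
    x = np.asarray(x, dtype=np.float64)
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

print(logsigmoid_reference([-10.0, 0.0, 10.0]))
# approximately [-1.000005e+01 -6.931472e-01 -4.539890e-05]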

mindnlp/core/ops/array.py

Lines changed: 6 additions & 1 deletion
@@ -271,6 +271,11 @@ def scatter_add(input, dim, index, src):
 
 
 # scatter_reduce
+def scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
+    if reduce == 'sum':
+        return scatter_add(input, dim, index, src)
+    else:
+        raise ValueError(f'do not support reduce: {reduce}')
 
 
 # split
@@ -1005,7 +1010,7 @@ def setitem_np(input, slice, value):
     # select_scatter
     # slice_scatter
     "scatter_add",
-    # scatter_reduce
+    "scatter_reduce",
     "split",
     "squeeze",
     "stack",

mindnlp/core/ops/random.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 
 
 # bernoulli
-def bernoulli(input, *, generator=None, out=None):
+def bernoulli(input, *, generator=None, out=None, **kwargs):
     if generator is None:
         generator = default_generator
     output = execute("bernoulli_ext", input, generator)
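The only change here is the trailing **kwargs, which lets bernoulli tolerate extra keyword arguments that torch-style callers may pass, instead of raising TypeError; the extras are simply ignored. A minimal caller-side illustration with a made-up keyword (some_flag is hypothetical, not a documented parameter):

def bernoulli_compat(input, *, generator=None, out=None, **kwargs):
    # Same signature shape as the patched function: unknown keywords
    # land in kwargs and are not used.
    return input, sorted(kwargs)

probs = [0.1, 0.9]                                     # stand-in for a tensor
_, ignored = bernoulli_compat(probs, some_flag=True)   # hypothetical keyword
print(ignored)                                         # ['some_flag']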
