Commit 6b90e5b

fix apis for g class (#2155)

1 parent 4173f14
15 files changed: +349 −35 lines
File renamed without changes.

mindnlp/core/_functorch/apis.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 from typing import Callable
-
 import mindspore

 def vmap(

mindnlp/core/_prims/ascend.py

Lines changed: 8 additions & 0 deletions
@@ -396,3 +396,11 @@ def cross(input, other, dim=None, *, out=None):
     return pyboost_inner_prim.cross_impl(input, other, dim)

 __all__.append('cross')
+
+def logit(input, eps):
+    if eps is None:
+        eps = -1.0
+    logit_ = _get_cache_prim(ops.Logit)(eps).set_device('Ascend')
+    return logit_(input)
+
+__all__.append('logit')

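For context, a minimal NumPy sketch of what the new logit prim computes (an illustrative helper, not part of the commit): the input is optionally clamped to [eps, 1 - eps] before taking log(p / (1 - p)), and the -1.0 sentinel forwarded above presumably tells the Ascend primitive to skip clamping.

import numpy as np

def logit_reference(p, eps=None):
    # Illustrative only: mirrors the contract read from the diff above.
    # eps=None (forwarded as -1.0) means "no clamping".
    p = np.asarray(p, dtype=np.float32)
    if eps is not None and eps >= 0:
        p = np.clip(p, eps, 1.0 - eps)
    return np.log(p / (1.0 - p))

print(logit_reference([0.25, 0.5, 0.75]))      # approx [-1.0986, 0.0, 1.0986]
print(logit_reference([0.0, 1.0], eps=1e-6))   # finite values thanks to clamping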
mindnlp/core/_prims/meta.py

Lines changed: 13 additions & 1 deletion
@@ -65,7 +65,9 @@ def getitem(input, slice):
 __all__.append('getitem')

 def sub_ext(input, other, alpha):
-    return input
+    if isinstance(input, core.Tensor):
+        return input
+    return other

 __all__.append('sub_ext')

@@ -341,3 +343,13 @@ def reverse_v2(input, dims):
     return input

 __all__.append('reverse_v2')
+
+def rsqrt(input):
+    return input
+
+__all__.append('rsqrt')
+
+def bitwise_xor_tensor(input, other):
+    return input
+
+__all__.append('bitwise_xor_tensor')

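These meta-backend prims only need to hand back a placeholder carrying the right shape and dtype, which is why rsqrt and bitwise_xor_tensor simply return their input. The sub_ext change matters when the first operand is a Python scalar (e.g. 1.0 - tensor): the tensor operand is then the only object carrying metadata. A rough sketch of the idea, using a made-up dataclass in place of the real meta tensor type:

from dataclasses import dataclass

@dataclass
class FakeMetaTensor:          # stand-in for a tensor that carries only metadata
    shape: tuple
    dtype: str

def sub_ext(input, other, alpha):
    # Return whichever operand actually is a tensor, so shape/dtype inference
    # still works when `input` is a Python scalar (scalar - tensor).
    if isinstance(input, FakeMetaTensor):
        return input
    return other

t = FakeMetaTensor((2, 3), 'float32')
print(sub_ext(t, 2.0, 1))      # tensor - scalar -> metadata of t
print(sub_ext(1.0, t, 1))      # scalar - tensor -> metadata of t (previously lost)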
mindnlp/core/_prims/numpy.py

Lines changed: 62 additions & 4 deletions
@@ -753,10 +753,13 @@ def std(input, dim, correction, keepdim):
 __all__.append('std')

 def meshgrid(tensors, indexing):
-    out = np.meshgrid(*[t.numpy() for t in tensors], indexing=indexing)
-    if not isinstance(out, np.ndarray):
-        out = np.array(out)
-    return core.Tensor.from_numpy(out)
+    outs = np.meshgrid(*[t.numpy() for t in tensors], indexing=indexing)
+    new_outs = ()
+    for out in outs:
+        if not isinstance(out, np.ndarray):
+            out = np.array(out)
+        new_outs += (core.Tensor.from_numpy(out),)
+    return new_outs

 __all__.append('meshgrid')

@@ -809,3 +812,58 @@ def reverse_v2(input, dims):
     return core.Tensor.from_numpy(out)

 __all__.append('reverse_v2')
+
+def rsqrt(input):
+    out = np.reciprocal(np.sqrt(input.numpy()))
+    if not isinstance(out, np.ndarray):
+        out = np.array(out)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('rsqrt')
+
+def bitwise_xor_tensor(input, other):
+    out = np.bitwise_xor(input.numpy(), other.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('bitwise_xor_tensor')
+
+def minimum(input, other):
+    out = np.minimum(input.numpy(), other.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('minimum')
+
+def prod_ext(input, dim, keepdim, dtype):
+    out = np.prod(input.numpy(), axis=dim, keepdims=keepdim)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('prod_ext')
+
+def select(condition, input, other):
+    if not isinstance(input, numbers.Number):
+        input = input.numpy()
+    if not isinstance(other, numbers.Number):
+        other = other.numpy()
+
+    out = np.where(condition.numpy(), input, other)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('select')
+
+def dense(input, weight, bias):
+    output = np.dot(input.numpy(), weight.numpy().T)
+    if bias is not None:
+        output += bias
+    return core.Tensor.from_numpy(output)
+
+__all__.append('dense')
+
+def dropout_ext(input, p):
+    if p != 0:
+        mask = (np.random.rand(*input.shape) < (1 - p))
+        out = input.numpy() * mask / (1 - p)
+        return core.Tensor.from_numpy(out), core.Tensor.from_numpy(mask)
+    else:
+        return input, None
+
+__all__.append('dropout_ext')

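As a sanity check on the new CPU/NumPy fallbacks, here is a standalone sketch of the inverted-dropout scheme that dropout_ext implements (names are made up, plain NumPy): surviving activations are rescaled by 1 / (1 - p) so the expected value is preserved, and the boolean keep-mask is returned alongside the output.

import numpy as np

def dropout_reference(x, p, seed=0):
    # Inverted dropout, mirroring the dropout_ext fallback added above.
    if p == 0:
        return x, None
    rng = np.random.default_rng(seed)
    mask = rng.random(x.shape) < (1 - p)
    return x * mask / (1 - p), mask

x = np.ones((4, 4), dtype=np.float32)
out, mask = dropout_reference(x, p=0.5)
print(mask.mean())   # roughly half of the elements survive
print(out.max())     # survivors are scaled up to 2.0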
mindnlp/core/_tensor.py

Lines changed: 11 additions & 6 deletions
@@ -196,7 +196,7 @@ def __len__(self):
     return self.shape[0]

 def __repr__(self) -> str:
-    self.data_sync(True)
+    # self.data_sync(True)
     return Tensor_.__repr__(self)[:-1] + f', device={self.device})'

 def __format__(self, format_spec):

@@ -982,8 +982,8 @@ def diagnoal(self, offset=0, dim1=0, dim2=1):


 # Tensor.div
-def div(self, other):
-    return ops.div(self, other)
+def div(self, other, rounding_mode=None):
+    return ops.div(self, other, rounding_mode=rounding_mode)

 # Tensor.div_
 def div_(self, other):

@@ -1257,13 +1257,18 @@ def index_add_(self, dim, index, source, *, alpha=1):

 # Tensor.index_add
 def index_add(self, dim, index, source, *, alpha=1):
-    return ops.index_add(self, dim, source, alpha=alpha)
+    return ops.index_add(self, dim, index, source, alpha=alpha)

 # Tensor.index_copy_
-
+def index_copy_(self, dim, index, tensor2):
+    return self.copy_(self.index_copy(dim, index, tensor2))

 # Tensor.index_copy
-
+def index_copy(self, dim, index, tensor2):
+    original_values_at_index = self.index_select(dim, index)
+    result = self.index_add(dim, index, -original_values_at_index)
+    result.index_add_(dim, index, tensor2)
+    return result

 # Tensor.index_fill_

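The new index_copy is composed from index_select and index_add: subtract the values currently stored at index, then add the replacement rows. A small NumPy sketch of that identity (illustrative only, dim=0 case, not the mindnlp API):

import numpy as np

def index_copy_reference(x, index, tensor2):
    # Rows listed in `index` are replaced by the rows of `tensor2`, expressed
    # as "x - old_rows_at_index + tensor2", matching the diff above.
    # Assumes `index` has no repeated entries, as with the Tensor method.
    result = x.copy()
    old = x[index]                   # index_select(0, index)
    np.add.at(result, index, -old)   # index_add(0, index, -old)
    np.add.at(result, index, tensor2)
    return result

x = np.arange(12, dtype=np.float32).reshape(4, 3)
idx = np.array([0, 2])
new_rows = np.zeros((2, 3), dtype=np.float32)
print(index_copy_reference(x, idx, new_rows))   # rows 0 and 2 become zeros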
Lines changed: 5 additions & 2 deletions
@@ -1,5 +1,8 @@
+import contextlib
+
 class SDPBackend:
-    pass
+    MATH = 0

+@contextlib.contextmanager
 def sdpa_kernel(*args, **kwargs):
-    pass
+    yield {}

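With sdpa_kernel now a real context manager and SDPBackend.MATH a usable constant, torch-style call sites stop crashing; the backend hint is presumably just ignored. A hypothetical usage sketch (the file name is collapsed in this view, so the import path below is an assumption):

# Assumed module path; adjust to wherever this shim actually lives.
from mindnlp.core.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.MATH):
    # Attention code ported from torch runs unchanged; the shim yields an
    # empty dict and does not actually switch kernels.
    ...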
mindnlp/core/nn/functional.py

Lines changed: 47 additions & 0 deletions
@@ -1590,9 +1590,56 @@ def pixel_shuffle(input, upscale_factor):
 def pixel_unshuffle(input, downscale_factor):
     return ops.pixel_unshuffle(input, downscale_factor)

+def getWH(input):
+    """Get [W, H] tensor from input"""
+    H, W = input.size()[-2:]
+    return core.tensor([[W, H]], dtype=core.float32, device=input.device)
+
+def center_of(input):
+    """return [(W-1)/2, (H-1)/2] tensor of input img"""
+    if input.dim() == 4:
+        H, W = input.size()[-2:]
+        shape = [[W, H]]
+    else:
+        D, H, W = input.size()[-3:]
+        shape = [[W, H, D]]
+    return core.tensor(shape, dtype=core.float32, device=input.device).sub_(1).div_(2)
+
+def u(s, a: float = -0.75):
+    s2, s3 = s**2, s**3
+    l1 = (a+2)*s3 - (a+3)*s2 + 1
+    l2 = a*s3 - (5*a)*s2 + (8*a)*s - 4*a
+    return l1.where(s <= 1, l2)
+
+def bicubic_grid_sample(input, grid, padding_mode: str = 'zeros', align_corners: bool = False):
+    """bicubic_grid_sample"""
+    kernel_size = 4
+    if not align_corners:
+        grid = grid * getWH(input) / getWH(input).sub_(1)
+    center = center_of(input)
+    abs_loc = ((grid + 1) * center).unsqueeze(-1)
+
+    locs = abs_loc.floor() + core.tensor([-1, 0, 1, 2], device=grid.device)
+
+    loc_w, loc_h = locs.detach().flatten(0, 2).unbind(dim=-2)
+    loc_w = loc_w.reshape(-1, 1, kernel_size).expand(-1, kernel_size, -1)
+    loc_h = loc_h.reshape(-1, kernel_size, 1).expand(-1, -1, kernel_size)
+    loc_grid = core.stack([loc_w, loc_h], dim=-1)
+    loc_grid = loc_grid.view(grid.size(0), -1, 1, 2)/center - 1
+
+    selected = grid_sample(input, loc_grid.detach(), mode='nearest',
+                           padding_mode=padding_mode, align_corners=True)
+    patch = selected.view(input.size()[:2]+grid.size()[1:3]+(kernel_size,)*2)
+
+    mat_r, mat_l = u(core.abs(abs_loc - locs.detach())).unbind(dim=-2)
+    output = core.einsum('bhwl,bchwlr,bhwr->bchw', mat_l, patch, mat_r)
+    return output
+
 def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
     align_corners = False if align_corners is None else align_corners
     if input.ndim == 4:
+        if mode == 'bicubic':
+            return bicubic_grid_sample(input, grid, padding_mode, align_corners)
         return execute('grid_sampler_2d', input, grid, mode, padding_mode, align_corners)
     return execute('grid_sampler_3d', input, grid, mode, padding_mode, align_corners)

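A hedged usage sketch: with the branch above, 4-D inputs can now request bicubic interpolation through the public grid_sample entry point. Shapes follow the usual convention (input is N x C x H x W, grid is N x H_out x W_out x 2 with coordinates in [-1, 1]); the torch-style factory calls below are assumptions about mindnlp.core's API surface.

from mindnlp import core
from mindnlp.core.nn import functional as F

image = core.randn(1, 3, 8, 8)          # N, C, H, W
grid = core.rand(1, 4, 4, 2) * 2 - 1    # N, H_out, W_out, 2 in [-1, 1]

out = F.grid_sample(image, grid, mode='bicubic', padding_mode='zeros',
                    align_corners=False)
print(out.shape)                         # expected (1, 3, 4, 4)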
mindnlp/core/ops/creation.py

Lines changed: 10 additions & 4 deletions
@@ -39,9 +39,9 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, requires_grad=F

     if isinstance(device, str):
         device = core.device(device)
-    if isinstance(size[0], (tuple, list)):
+    if len(size) > 0 and isinstance(size[0], (tuple, list)):
         size = size[0]
-
+
     new_size = ()
     for s in size:
         if not isinstance(s, int):

@@ -139,6 +139,10 @@ def linspace(start, end, steps, *, out=None, dtype=None, layout=None, device=Non
     if isinstance(device, str):
         device = core.device(device)

+    start = start.item() if isinstance(start, (core.Tensor, np.integer)) else start
+    end = end.item() if isinstance(end, (core.Tensor, np.integer)) else end
+    steps = steps.item() if isinstance(steps, (core.Tensor, np.integer)) else steps
+
     output = execute('lin_space_ext', start, end, steps, dtype,
                      device=device, requires_grad=requires_grad, user_created=True)
     if out is None:

@@ -154,6 +158,8 @@ def eye(n, m=None, *, out=None, dtype=None, layout=None, device=None, requires_g
         device = get_device_in_context()
     if dtype is None:
         dtype = get_default_dtype()
+    if m is None:
+        m = n
     output = execute('eye', n, m, dtype,
                      device=device, requires_grad=requires_grad, user_created=True)
     if out is None:

@@ -194,8 +200,8 @@ def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=Fal

 # full
 def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, requires_grad=False):
-    if dtype is None:
-        dtype = get_default_dtype()
+    # if dtype is None:
+    #     dtype = get_default_dtype()
     if device is None:
         device = get_device_in_context()
     if device.type == 'cpu':

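Taken together, the creation fixes make a few torch-style call patterns behave as expected. A hedged sketch, assuming mindnlp.core is used as a torch-like drop-in (the exact factory behaviors are inferred from the hunks above):

from mindnlp import core

# eye(n) without an explicit m now produces a square identity matrix.
i = core.eye(3)

# linspace accepts 0-d tensors / NumPy integers for its scalar arguments.
steps = core.tensor(5)
xs = core.linspace(0.0, 1.0, steps)

# full() no longer forces the default dtype, presumably letting the fill
# value (or downstream inference) determine it.
ones = core.full((2, 2), 1)

print(i.shape, xs.shape, ones.shape)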
mindnlp/core/ops/reduction.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def nansum(input, dim=None, keepdim=False, *, dtype=None):

 # prod
 def prod(input, dim=None, keepdim=False, *, dtype=None):
-    return execute('prod_ext', input, dim, keepdim,dtype)
+    return execute('prod_ext', input, dim, keepdim, dtype)

 # quantile
