
Commit 75c4720

fix e class ut (#2071)
1 parent 8234dcb commit 75c4720

6 files changed, +65 -22 lines changed


mindnlp/core/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -96,4 +96,10 @@ def set_autocast_dtype(device_type, dtype):
 def get_autocast_dtype(device_type):
     return AUTO_CAST_DTYE[device_type]
 
+def get_autocast_gpu_dtype():
+    return AUTO_CAST_DTYE['cuda']
+
+def is_autocast_enabled():
+    return True
+
 __version__ = 'test_version_no_value'
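Usage sketch (illustrative, not part of the diff; assumes mindnlp.core re-exports float16 alongside the autocast helpers shown above):

import mindnlp.core as core

core.set_autocast_dtype('cuda', core.float16)  # register the CUDA autocast dtype
print(core.get_autocast_gpu_dtype())           # reads AUTO_CAST_DTYE['cuda'] -> float16
print(core.is_autocast_enabled())              # hard-coded stub, always True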

mindnlp/core/_tensor.py

Lines changed: 38 additions & 8 deletions
@@ -283,13 +283,13 @@ def __setitem__(self, slices, value):
         value = ops.finfo(self.dtype).max
     elif value == -float('inf'):
         value = ops.finfo(self.dtype).min
-    # if isinstance(slices, tuple):
-    #     new_slices = ()
-    #     for s in slices:
-    #         if isinstance(s, range):
-    #             s = list(s)
-    #         new_slices += (s,)
-    #     slices = new_slices
+    if isinstance(slices, tuple):
+        new_slices = ()
+        for s in slices:
+            if isinstance(s, range):
+                s = list(s)
+            new_slices += (s,)
+        slices = new_slices
     if not isinstance(value, Tensor):
         value = tensor(value, dtype=self.dtype)
     return origin_setitem(self, slices, value)

@@ -507,10 +507,40 @@ def __repr__(self):
 Tensor.__repr__ = __repr__
 StubTensor.__repr__ = _stub_method(__repr__)
 
-
 def detach_(self):
     return ops.stop_gradient(self)
 
+Tensor.detach_ = detach_
+StubTensor.detach_ = detach_
+
+def new_full(self, size, fill_value, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
+    return ops.full(size, fill_value, dtype=dtype if dtype is not None else self.dtype)
+
+Tensor.new_full = new_full
+StubTensor.new_full = new_full
+
+def new_zeros(self, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
+    return ops.zeros(*size, dtype=dtype if dtype is not None else self.dtype)
+
+Tensor.new_zeros = new_zeros
+StubTensor.new_zeros = new_zeros
+
+Tensor.sum = ops.sum
+StubTensor.sum = ops.sum
+
+def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
+    return tensor(data, dtype=dtype if dtype is not None else self.dtype)
+
+Tensor.new_tensor = new_tensor
+StubTensor.new_tensor = new_tensor
+
+Tensor.fill_diagonal_ = ops.inplace_fill_diagonal
+StubTensor.fill_diagonal_ = ops.inplace_fill_diagonal
+
+Tensor.triu_ = ops.inplace_triu
+StubTensor.triu_ = ops.inplace_triu
+
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
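Usage sketch for the newly patched methods (illustrative; assumes importing mindnlp applies the Tensor/StubTensor patches above):

import mindspore
import mindnlp  # assumed to apply the patches on import

x = mindspore.Tensor([[1.0, 2.0], [3.0, 4.0]])
a = x.new_full((2, 2), 7.0)   # filled with 7.0, dtype inherited from x
b = x.new_zeros(2, 3)         # zeros with x's dtype
c = x.new_tensor([5.0, 6.0])  # wraps new data with x's dtype
x.fill_diagonal_(0.0)         # in place, via ops.inplace_fill_diagonal
x.triu_(1)                    # in place, via ops.inplace_triu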

mindnlp/core/cuda/amp/autocast_mode.py

Lines changed: 0 additions & 11 deletions
@@ -26,29 +26,18 @@ def __init__(
         dtype: core.dtype = core.float16,
         cache_enabled: bool = True,
     ):
-        if core._jit_internal.is_scripting():
-            self._enabled = enabled
-            self.device = "cuda"
-            self.fast_dtype = dtype
-            return
         super().__init__(
             "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
         )
 
     def __enter__(self):
-        if core._jit_internal.is_scripting():
-            return self
         return super().__enter__()
 
     # TODO: discuss a unified TorchScript-friendly API for autocast
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
-        if core._jit_internal.is_scripting():
-            return
         return super().__exit__(exc_type, exc_val, exc_tb)
 
     def __call__(self, func):
-        if core._jit_internal.is_scripting():
-            return func
         return super().__call__(func)
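Usage sketch (illustrative; the autocast class name and import path are assumptions based on the torch.cuda.amp layout this file mirrors):

import mindnlp.core as core
from mindnlp.core.cuda.amp import autocast

with autocast(dtype=core.float16):
    pass  # ops here run under the 'cuda' autocast policy

@autocast()  # __call__ wraps a function the same way
def forward(x):
    return x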

mindnlp/core/linalg/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -22,4 +22,4 @@ def cholesky_ex(A, *, upper=False, check_errors=False, out=None):
 
 
 def norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None):
-    return mint.norm(A, ord, dim, keepdim, dtype=dtype)
+    return mint.norm(A, 2 if ord is None else ord, dim, keepdim, dtype=dtype)
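The change maps ord=None to an explicit 2 before calling mint.norm, so the default is the Euclidean norm. A minimal check (illustrative; expected values assume standard norm semantics):

import mindspore
from mindnlp.core import linalg

v = mindspore.Tensor([3.0, 4.0])
print(linalg.norm(v))         # default now means ord=2 -> 5.0
print(linalg.norm(v, ord=1))  # explicit ord passes through -> 7.0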

mindnlp/core/ops/inplace.py

Lines changed: 16 additions & 1 deletion
@@ -118,6 +118,19 @@ def inplace_unsqueeze(input, dim=None):
     input.assign_value(out)
     return input
 
+def inplace_fill_diagonal(input, fill_value, wrap=False):
+    fill_diagnoal_ = _get_cache_prim(ops.FillDiagonal)(float(fill_value), wrap)
+    out = fill_diagnoal_(input)
+    input.assign_value(out)
+    return input
+
+def inplace_triu(input, diagonal=0):
+    out = ops.triu(input, diagonal)
+    input.assign_value(out)
+    return input
+
+
+
 __all__ = [
     'inplace_copy',
     'inplace_zero',

@@ -129,5 +142,7 @@ def inplace_unsqueeze(input, dim=None):
     'inplace_index_copy',
     'inplace_index_add',
     'inplace_squeeze',
-    'inplace_unsqueeze'
+    'inplace_unsqueeze',
+    'inplace_fill_diagonal',
+    'inplace_triu'
 ]
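Both helpers follow the existing inplace_* pattern: compute out of place, then assign_value back into the input. Usage sketch (illustrative; assumes these names are re-exported from mindnlp.core.ops):

import mindspore
from mindnlp.core import ops

m = mindspore.ops.ones((3, 3), mindspore.float32)
ops.inplace_fill_diagonal(m, 0.0)  # writes 0.0 along the diagonal of m
ops.inplace_triu(m, diagonal=1)    # keeps only the strict upper triangle of m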

mindnlp/core/ops/reduction.py

Lines changed: 4 additions & 1 deletion
@@ -181,7 +181,10 @@ def std_mean(input, dim=None, *, correction=1, keepdim=False):
 
 # sum
 has_sum = hasattr(mindspore.mint, 'sum')
-def sum(input, dim=None, keepdim=False, *, dtype=None):
+def sum(input, dim=None, keepdim=False, *, dtype=None, **kwargs):
+    keepdims = kwargs.pop('keepdims', None)
+    if keepdims is not None:
+        keepdim = keepdims
     if 0 in input.shape:
         return mindspore.tensor(0, dtype=dtype)
     if use_pyboost() and has_sum:
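The shim now accepts the numpy-style keepdims spelling as an alias for the torch-style keepdim, so both call styles work (illustrative):

import mindspore
from mindnlp.core import ops

x = mindspore.ops.ones((2, 3), mindspore.float32)
s1 = ops.sum(x, dim=1, keepdim=True)   # torch spelling -> shape (2, 1)
s2 = ops.sum(x, dim=1, keepdims=True)  # numpy spelling, same result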
