
Commit 8234dcb

fix d class ut (#2070)
1 parent 3f85596 commit 8234dcb

File tree

13 files changed: +384 -22 lines changed

mindnlp/core/_tensor.py

Lines changed: 44 additions & 2 deletions
@@ -6,7 +6,7 @@
 from mindspore.common.tensor import _TensorMeta
 from mindspore._c_expression.typing import Type
 try:
-    from mindspore.common._stub_tensor import StubTensor
+    from mindspore.common._stub_tensor import StubTensor, _stub_method
 except:
     class StubTensor: pass

@@ -17,7 +17,7 @@ class StubTensor: pass

 from . import ops, _dtype
 from ._dtype import dtype2np
-from ._bind import get_default_device, device_
+from ._bind import get_default_device, device_, get_default_dtype
 from .configs import use_pyboost, ON_A1
 from .storage import UntypedStorage
 from ._utils import _rebuild_tensor_v2
@@ -98,6 +98,16 @@ def is_tensor(x):
     return isinstance(x, Tensor)

 def enable_mindspore_patch():
+    old_init = Tensor.__init__
+    def __init__(self, *args, **kwargs):
+        if len(args) > 1 and all([isinstance(arg, int) for arg in args]):
+            tensor = Tensor_(shape=args, dtype=get_default_dtype())
+            old_init(self, tensor, internal=True)
+        else:
+            old_init(self, *args, **kwargs)
+
+    Tensor.__init__ = __init__
+
     def __reduce_ex__(self, protocol):
         if isinstance(self, StubTensor):
             data = Tensor_(self.stub_sync())
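
Editor's note: the new `__init__` wrapper adds a torch-style shape constructor. A minimal sketch of the behavior it enables, assuming `enable_mindspore_patch()` has already run (import path illustrative):

    # Sketch only: after enable_mindspore_patch(), multiple integer positional
    # arguments are interpreted as a shape, as in torch.Tensor(2, 3).
    from mindspore import Tensor

    t = Tensor(2, 3)       # 2x3 tensor allocated with the default dtype
    u = Tensor([1, 2, 3])  # a single non-int argument still takes the old path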
@@ -280,6 +290,8 @@ def __setitem__(self, slices, value):
         # s = list(s)
         # new_slices += (s,)
         # slices = new_slices
+        if not isinstance(value, Tensor):
+            value = tensor(value, dtype=self.dtype)
         return origin_setitem(self, slices, value)

     Tensor.__setitem__ = __setitem__
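
Editor's note: the `__setitem__` change coerces plain Python values to tensors before delegating to the original setter. A hedged sketch of what this enables (`x` assumed to be a patched Tensor):

    # Sketch only: non-Tensor right-hand sides are wrapped via
    # tensor(value, dtype=self.dtype) before the original __setitem__ runs.
    x[0] = 1.5           # scalar coerced to a Tensor with x's dtype
    x[1:3] = [2.0, 3.0]  # list coerced the same way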
@@ -469,6 +481,36 @@ def pin_memory(self, *args, **kwargs):
     Tensor.pin_memory = pin_memory
     StubTensor.pin_memory = pin_memory

+    def __deepcopy__(self, memodict):
+        new_obj = Tensor(self)
+        return new_obj
+
+    Tensor.__deepcopy__ = __deepcopy__
+    StubTensor.__deepcopy__ = __deepcopy__
+
+    def asnumpy(self):
+        return Tensor_.asnumpy(self)
+
+    Tensor.asnumpy = asnumpy
+    StubTensor.asnumpy = _stub_method(asnumpy)
+
+    def backward(self, *args, **kwargs):
+        pass
+
+    Tensor.backward = backward
+    StubTensor.backward = backward
+
+    def __repr__(self):
+        Tensor_.data_sync(self, True)
+        return Tensor_.__repr__(self)
+
+    Tensor.__repr__ = __repr__
+    StubTensor.__repr__ = _stub_method(__repr__)
+
+
+    def detach_(self):
+        return ops.stop_gradient(self)
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
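
Editor's note: `detach_` is defined here but, in the lines shown, never assigned onto `Tensor` or `StubTensor`, so the binding presumably happens elsewhere. The other patches give tensors torch-like conveniences; a hedged usage sketch (`x` assumed to be a patched Tensor):

    import copy

    y = copy.deepcopy(x)  # routed through __deepcopy__, returns a fresh Tensor
    arr = x.asnumpy()     # NumPy conversion via Tensor_.asnumpy
    x.backward()          # intentionally a no-op in this compatibility layer
    print(x)              # __repr__ forces a data sync before formatting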

mindnlp/core/backends/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from . import cuda, mps
+from . import cuda, mps, cudnn
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+from contextlib import contextmanager
+
+@contextmanager
+def flags(
+    enabled=False,
+    benchmark=False,
+    benchmark_limit=10,
+    deterministic=False,
+    allow_tf32=True,
+    fp32_precision="none",
+):
+    try:
+        yield
+    finally:
+        pass
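
Editor's note: the new `cudnn.flags` context manager is a no-op stub mirroring `torch.backends.cudnn.flags`: it accepts the same keyword arguments and changes no backend state. A usage sketch (module path inferred from the `from . import cuda, mps, cudnn` change above; `run_model` is a hypothetical callable):

    from mindnlp.core.backends import cudnn

    # Accepted for API compatibility; nothing is actually toggled.
    with cudnn.flags(enabled=True, benchmark=True, deterministic=False):
        run_model()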

mindnlp/core/nn/functional.py

Lines changed: 191 additions & 5 deletions
@@ -6,6 +6,16 @@
 import mindspore
 from mindspore import ops, mint
 from mindspore.ops._primitive_cache import _get_cache_prim
+from mindspore.ops.auto_generate import (reflection_pad_1d_op, reflection_pad_2d_op, add_layernorm_v2_op,
+                                         reflection_pad_3d_op, # pylint: disable=W0611
+                                         replication_pad_1d_op, replication_pad_2d_op, replication_pad_3d_op,
+                                         constant_pad_nd_op, dropout_ext_op, reverse_v2_impl, avg_pool2d_op,
+                                         upsample_nearest1d_op, upsample_nearest2d_op, upsample_nearest3d_op,
+                                         upsample_linear1d_op, upsample_bilinear2d_op, upsample_bicubic2d_op,
+                                         upsample_trilinear3d_impl, fill_scalar_op, floor_op, nllloss_2d_op,
+                                         masked_fill_op, masked_select, ones, flatten_ext, conv_transpose2d)
+
+

 from mindnlp import core
 from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1
@@ -243,7 +253,11 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, sca
         return mint.nn.functional.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq)
     return ops.gather(weight, input, 0)

-def rms_norm(input, normalized_shape, weight, eps=1e-5):
+def rms_norm(input, normalized_shape, weight, eps=None):
+    if eps is None:
+        eps = core.finfo(input.dtype).eps
+    if weight is None:
+        weight = core.ones(normalized_shape)
     return ops.rms_norm(input, weight, eps)[0]

 def fast_gelu(x):
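
Editor's note: with the new defaults, `rms_norm` derives `eps` from the input dtype's machine epsilon and synthesizes a ones weight when none is given. A hedged sketch (`x` and `hidden` assumed):

    # Sketch only: for float32 input, these two calls should be equivalent.
    out1 = rms_norm(x, (hidden,), None)
    out2 = rms_norm(x, (hidden,), core.ones((hidden,)),
                    eps=core.finfo(core.float32).eps)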
@@ -463,7 +477,161 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
     return _layer_norm(input, weight, bias)[0]

 def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
-    return ops.interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor)
+    if mode in ("nearest", "area", "nearest-exact"):
+        if align_corners is not None:
+            raise ValueError(
+                "align_corners option can only be set with the "
+                "interpolating modes: linear | bilinear | bicubic | trilinear"
+            )
+    else:
+        if align_corners is None:
+            align_corners = False
+
+    dim = input.dim() - 2  # Number of spatial dimensions.
+
+    # Process size and scale_factor. Validate that exactly one is set.
+    # Validate its length if it is a list, or expand it if it is a scalar.
+    # After this block, exactly one of output_size and scale_factors will
+    # be non-None, and it will be a list (or tuple).
+    if size is not None and scale_factor is not None:
+        raise ValueError("only one of size or scale_factor should be defined")
+    elif size is not None:
+        assert scale_factor is None
+        scale_factors = None
+        if isinstance(size, (list, tuple)):
+            if len(size) != dim:
+                raise ValueError(
+                    "Input and output must have the same number of spatial dimensions, but got "
+                    f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
+                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
+                    "output size in (o1, o2, ...,oK) format."
+                )
+            output_size = size
+        else:
+            output_size = [size for _ in range(dim)]
+    elif scale_factor is not None:
+        assert size is None
+        output_size = None
+        if isinstance(scale_factor, (list, tuple)):
+            if len(scale_factor) != dim:
+                raise ValueError(
+                    "Input and scale_factor must have the same number of spatial dimensions, but "
+                    f"got input with spatial dimensions of {list(input.shape[2:])} and "
+                    f"scale_factor of shape {scale_factor}. "
+                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
+                    "scale_factor in (s1, s2, ...,sK) format."
+                )
+            scale_factors = scale_factor
+        else:
+            scale_factors = [scale_factor for _ in range(dim)]
+    else:
+        raise ValueError("either size or scale_factor should be defined")
+
+    if (
+        recompute_scale_factor is not None
+        and recompute_scale_factor
+        and size is not None
+    ):
+        raise ValueError(
+            "recompute_scale_factor is not meaningful with an explicit size."
+        )
+
+    # "area" mode always requires an explicit size rather than scale factor.
+    # Re-use the recompute_scale_factor code path.
+    if mode in ["area", "bilinear"] and output_size is None:
+        recompute_scale_factor = True
+
+    if recompute_scale_factor is not None and recompute_scale_factor:
+        # We compute output_size here, then un-set scale_factors.
+        # The C++ code will recompute it based on the (integer) output size.
+        assert scale_factors is not None
+        # make scale_factor a tensor in tracing so constant doesn't get baked in
+        output_size = [
+            (
+                math.floor(
+                    float(input.size(i + 2) * scale_factors[i])
+                )
+            )
+            for i in range(dim)
+        ]
+        scale_factors = None
+
+    if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4):
+        raise ValueError(
+            "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input"
+        )
+
+    if input.dim() == 3 and mode == "nearest":
+        return upsample_nearest1d_op(input, output_size, scale_factors)
+    if input.dim() == 4 and mode == "nearest":
+        return upsample_nearest2d_op(input, output_size, scale_factors)
+    if input.dim() == 5 and mode == "nearest":
+        return upsample_nearest3d_op(input, output_size, scale_factors)
+
+    if input.dim() == 3 and mode == "nearest-exact":
+        return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
+    if input.dim() == 4 and mode == "nearest-exact":
+        return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
+    if input.dim() == 5 and mode == "nearest-exact":
+        return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
+
+    if input.dim() == 3 and mode == "area":
+        assert output_size is not None
+        return adaptive_avg_pool1d(input, output_size)
+    if input.dim() == 4 and mode == "area":
+        assert output_size is not None
+        return adaptive_avg_pool2d(input, output_size)
+    if input.dim() == 5 and mode == "area":
+        assert output_size is not None
+        return adaptive_avg_pool3d(input, output_size)
+
+    if input.dim() == 3 and mode == "linear":
+        assert align_corners is not None
+        return upsample_linear1d_op(
+            input, output_size, scale_factors, align_corners
+        )
+    if input.dim() == 4 and mode == "bilinear":
+        assert align_corners is not None
+        if antialias:
+            return torch._C._nn._upsample_bilinear2d_aa(
+                input, output_size, align_corners, scale_factors
+            )
+        return upsample_bilinear2d_op(
+            input, output_size, scale_factors, align_corners
+        )
+    if input.dim() == 5 and mode == "trilinear":
+        assert align_corners is not None
+        return upsample_trilinear3d_impl(
+            input, output_size, scale_factors, align_corners
+        )
+    if input.dim() == 4 and mode == "bicubic":
+        assert align_corners is not None
+        if antialias:
+            return torch._C._nn._upsample_bicubic2d_aa(
+                input, output_size, align_corners, scale_factors
+            )
+        return upsample_bicubic2d_op(
+            input, output_size, scale_factors, align_corners
+        )
+
+    if input.dim() == 3 and mode == "bilinear":
+        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
+    if input.dim() == 3 and mode == "trilinear":
+        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
+    if input.dim() == 4 and mode == "linear":
+        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
+    if input.dim() == 4 and mode == "trilinear":
+        raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
+    if input.dim() == 5 and mode == "linear":
+        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
+    if input.dim() == 5 and mode == "bilinear":
+        raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
+
+    raise NotImplementedError(
+        "Input Error: Only 3D, 4D and 5D input Tensors supported"
+        f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
+        f" (got {mode})"
+    )

 def normalize(input, p=2.0, dim=1, eps=1e-6):
     r"""
@@ -599,8 +767,24 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     raise ValueError("Requires mindspore >= 2.3.0 by default, or set into pyboost mode by calling torch.config.set_byboost(True).")

 def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    return mint.nn.functional.conv_transpose1d(input, weight, bias, stride, padding, output_padding, groups, dilation)
-
+    x_2d = input.unsqueeze(2)  # (batch, in_channels, 1, L_in)
+
+    # 2. Add a height dimension to the kernel
+    weight_2d = weight.unsqueeze(2)  # (in_channels, out_channels, 1, kernel_size)
+
+    # 3. Run a 2-D transposed convolution
+    output_2d = conv_transpose2d(
+        x_2d,
+        weight_2d,
+        bias,
+        stride=(1,) + stride,
+        padding=(0,) + padding,
+        output_padding=(0,) + output_padding,
+        dilation=(1,) + dilation
+    )  # output shape: (batch, out_channels, 1, L_out)
+
+    # 4. Remove the height dimension to restore the 1-D result
+    return output_2d.squeeze(2)

 def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
     return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)
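
Editor's note: `conv_transpose1d` is now emulated through `conv_transpose2d` by inserting a dummy height dimension. As written, the new body concatenates tuples (`(1,) + stride`, etc.) and drops `groups`, so it implicitly assumes the hyperparameters arrive as tuples and `groups == 1`; the integer defaults in the signature would need normalizing first. A hedged sketch of a call matching those assumptions (`x`, `w` assumed):

    # Sketch only: tuple hyperparameters, groups left at 1.
    out = conv_transpose1d(x, w, bias=None, stride=(2,), padding=(1,),
                           output_padding=(0,), dilation=(1,))
    # internally: (N, C, L) -> (N, C, 1, L) -> conv_transpose2d -> squeeze(2)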
@@ -1221,7 +1405,9 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
     return ops.fold(input, output_size, kernel_size, dilation, padding, stride)

 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
-    ctc_loss_op = _get_cache_prim(nn_ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
+    ctc_loss_op = _get_cache_prim(ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
+    if targets.ndim == 1:
+        targets = targets.unsqueeze(-1)
     loss, _ = ctc_loss_op(log_probs, targets, input_lengths, target_lengths)
    if zero_infinity:
         loss = ops.where(ops.isinf(loss), 0., loss)
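
Editor's note: the `ctc_loss` fix replaces the undefined `nn_ops` reference with `ops.CTCLossV2` and reshapes 1-D targets to 2-D before calling the primitive. A hedged sketch:

    # Sketch only: 1-D targets are unsqueezed to (len, 1) before the op runs.
    loss = ctc_loss(log_probs,      # (T, N, C), log-softmaxed
                    targets,        # 1-D targets now accepted
                    input_lengths, target_lengths,
                    blank=0, reduction='mean', zero_infinity=True)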

mindnlp/core/nn/modules/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 from .container import ModuleList, ParameterList, Sequential, ParameterDict, ModuleDict
 from .linear import Linear, Identity
 from .sparse import Embedding
-from .normalization import LayerNorm, GroupNorm
+from .normalization import LayerNorm, GroupNorm, RMSNorm
 from .dropout import Dropout, Dropout2d
 from .activation import *
 from .conv import Conv3d, Conv2d, Conv1d, ConvTranspose2d, ConvTranspose1d
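
Editor's note: `RMSNorm` is now re-exported from `nn.modules`. Assuming it mirrors the torch-style module API (argument names are hypothetical here), a usage sketch:

    from mindnlp.core import nn

    norm = nn.RMSNorm(768, eps=1e-6)  # args assumed to mirror torch.nn.RMSNorm
    y = norm(x)                       # x: (..., 768)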

mindnlp/core/nn/modules/batchnorm.py

Lines changed: 8 additions & 2 deletions
@@ -45,8 +45,14 @@ def __init__(
            self.register_buffer('running_var', ops.ones(num_features,))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
-           self.register_buffer('num_batches_tracked',
-                                Tensor(0, dtype=core.int64))
+           self.register_buffer(
+               "num_batches_tracked",
+               core.tensor(
+                   0,
+                   dtype=core.long,
+                   **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
+               ),
+           )
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
self.register_buffer("running_mean", None)
