import mindspore
from mindspore import ops, mint
from mindspore.ops._primitive_cache import _get_cache_prim
+from mindspore.ops.auto_generate import (reflection_pad_1d_op, reflection_pad_2d_op, add_layernorm_v2_op,
+                                         reflection_pad_3d_op,  # pylint: disable=W0611
+                                         replication_pad_1d_op, replication_pad_2d_op, replication_pad_3d_op,
+                                         constant_pad_nd_op, dropout_ext_op, reverse_v2_impl, avg_pool2d_op,
+                                         upsample_nearest1d_op, upsample_nearest2d_op, upsample_nearest3d_op,
+                                         upsample_linear1d_op, upsample_bilinear2d_op, upsample_bicubic2d_op,
+                                         upsample_trilinear3d_impl, fill_scalar_op, floor_op, nllloss_2d_op,
+                                         masked_fill_op, masked_select, ones, flatten_ext, conv_transpose2d)
+
+

from mindnlp import core
from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1

@@ -243,7 +253,11 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, sca
        return mint.nn.functional.embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq)
    return ops.gather(weight, input, 0)

-def rms_norm(input, normalized_shape, weight, eps=1e-5):
+def rms_norm(input, normalized_shape, weight=None, eps=None):
+    if eps is None:
+        eps = core.finfo(input.dtype).eps
+    if weight is None:
+        weight = core.ones(normalized_shape, dtype=input.dtype)
    return ops.rms_norm(input, weight, eps)[0]

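For reference, a minimal usage sketch of the new defaults (the tensor `x` and the direct in-module call are illustrative assumptions, not part of the commit): when `weight` and `eps` are omitted, the function now falls back to a ones gamma over `normalized_shape` and the machine epsilon of the input dtype, mirroring `torch.nn.functional.rms_norm`.

    x = core.randn(2, 4, 8)                                        # assumed helper; any float tensor works
    y_explicit = rms_norm(x, (8,), core.ones(8, dtype=x.dtype), eps=1e-5)
    y_default = rms_norm(x, (8,))                                   # weight -> ones((8,)), eps -> finfo(dtype).eps
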
def fast_gelu(x):
@@ -463,7 +477,161 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    return _layer_norm(input, weight, bias)[0]

def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
-    return ops.interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor)
+    if mode in ("nearest", "area", "nearest-exact"):
+        if align_corners is not None:
+            raise ValueError(
+                "align_corners option can only be set with the "
+                "interpolating modes: linear | bilinear | bicubic | trilinear"
+            )
+    else:
+        if align_corners is None:
+            align_corners = False
+
+    dim = input.dim() - 2  # Number of spatial dimensions.
+
+    # Process size and scale_factor. Validate that exactly one is set.
+    # Validate its length if it is a list, or expand it if it is a scalar.
+    # After this block, exactly one of output_size and scale_factors will
+    # be non-None, and it will be a list (or tuple).
+    if size is not None and scale_factor is not None:
+        raise ValueError("only one of size or scale_factor should be defined")
+    elif size is not None:
+        assert scale_factor is None
+        scale_factors = None
+        if isinstance(size, (list, tuple)):
+            if len(size) != dim:
+                raise ValueError(
+                    "Input and output must have the same number of spatial dimensions, but got "
+                    f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. "
+                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
+                    "output size in (o1, o2, ...,oK) format."
+                )
+            output_size = size
+        else:
+            output_size = [size for _ in range(dim)]
+    elif scale_factor is not None:
+        assert size is None
+        output_size = None
+        if isinstance(scale_factor, (list, tuple)):
+            if len(scale_factor) != dim:
+                raise ValueError(
+                    "Input and scale_factor must have the same number of spatial dimensions, but "
+                    f"got input with spatial dimensions of {list(input.shape[2:])} and "
+                    f"scale_factor of shape {scale_factor}. "
+                    "Please provide input tensor in (N, C, d1, d2, ...,dK) format and "
+                    "scale_factor in (s1, s2, ...,sK) format."
+                )
+            scale_factors = scale_factor
+        else:
+            scale_factors = [scale_factor for _ in range(dim)]
+    else:
+        raise ValueError("either size or scale_factor should be defined")
+
+    if (
+        recompute_scale_factor is not None
+        and recompute_scale_factor
+        and size is not None
+    ):
+        raise ValueError(
+            "recompute_scale_factor is not meaningful with an explicit size."
+        )
+
+    # "area" mode always requires an explicit size rather than a scale factor,
+    # and the bilinear backend used here does as well, so re-use the
+    # recompute_scale_factor code path for both.
+    if mode in ["area", "bilinear"] and output_size is None:
+        recompute_scale_factor = True
+
+    if recompute_scale_factor is not None and recompute_scale_factor:
+        # We compute output_size here, then un-set scale_factors.
+        # The backend op recomputes the scales from the integer output size.
+        assert scale_factors is not None
+        output_size = [
+            math.floor(float(input.size(i + 2) * scale_factors[i]))
+            for i in range(dim)
+        ]
+        scale_factors = None
+
+    if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4):
+        raise ValueError(
+            "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input"
+        )
+
+    if input.dim() == 3 and mode == "nearest":
+        return upsample_nearest1d_op(input, output_size, scale_factors)
+    if input.dim() == 4 and mode == "nearest":
+        return upsample_nearest2d_op(input, output_size, scale_factors)
+    if input.dim() == 5 and mode == "nearest":
+        return upsample_nearest3d_op(input, output_size, scale_factors)
+
+    if input.dim() == 3 and mode == "nearest-exact":
+        return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
+    if input.dim() == 4 and mode == "nearest-exact":
+        return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
+    if input.dim() == 5 and mode == "nearest-exact":
+        return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
+
+    if input.dim() == 3 and mode == "area":
+        assert output_size is not None
+        return adaptive_avg_pool1d(input, output_size)
+    if input.dim() == 4 and mode == "area":
+        assert output_size is not None
+        return adaptive_avg_pool2d(input, output_size)
+    if input.dim() == 5 and mode == "area":
+        assert output_size is not None
+        return adaptive_avg_pool3d(input, output_size)
+
+    if input.dim() == 3 and mode == "linear":
+        assert align_corners is not None
+        return upsample_linear1d_op(
+            input, output_size, scale_factors, align_corners
+        )
+    if input.dim() == 4 and mode == "bilinear":
+        assert align_corners is not None
+        if antialias:
+            return torch._C._nn._upsample_bilinear2d_aa(
+                input, output_size, align_corners, scale_factors
+            )
+        return upsample_bilinear2d_op(
+            input, output_size, scale_factors, align_corners
+        )
+    if input.dim() == 5 and mode == "trilinear":
+        assert align_corners is not None
+        return upsample_trilinear3d_impl(
+            input, output_size, scale_factors, align_corners
+        )
+    if input.dim() == 4 and mode == "bicubic":
+        assert align_corners is not None
+        if antialias:
+            return torch._C._nn._upsample_bicubic2d_aa(
+                input, output_size, align_corners, scale_factors
+            )
+        return upsample_bicubic2d_op(
+            input, output_size, scale_factors, align_corners
+        )
+
+    if input.dim() == 3 and mode == "bilinear":
+        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
+    if input.dim() == 3 and mode == "trilinear":
+        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
+    if input.dim() == 4 and mode == "linear":
+        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
+    if input.dim() == 4 and mode == "trilinear":
+        raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
+    if input.dim() == 5 and mode == "linear":
+        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
+    if input.dim() == 5 and mode == "bilinear":
+        raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
+
+    raise NotImplementedError(
+        "Input Error: Only 3D, 4D and 5D input Tensors supported"
+        f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact"
+        f" (got {mode})"
+    )

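A quick dispatch sketch for the rewritten `interpolate` (the tensor shapes and the `core.randn` helper are illustrative assumptions, not part of the commit): `nearest` with a scale factor routes to `upsample_nearest2d_op`, while `bilinear` either takes an explicit size or has one recomputed from the scale factor before the op is called; passing `size` and `scale_factor` together raises a `ValueError`.

    x = core.randn(1, 3, 16, 16)                               # (N, C, H, W)
    up_nn = interpolate(x, scale_factor=2.0, mode="nearest")   # upsample_nearest2d_op -> (1, 3, 32, 32)
    up_bi = interpolate(x, size=(24, 24), mode="bilinear")     # align_corners defaults to False here
    # interpolate(x, size=(24, 24), scale_factor=2.0)          # would raise ValueError
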
def normalize(input, p=2.0, dim=1, eps=1e-6):
    r"""
@@ -599,8 +767,24 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    raise ValueError("Requires mindspore >= 2.3.0 by default, or set into pyboost mode by calling torch.config.set_byboost(True).")

def conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
-    return mint.nn.functional.conv_transpose1d(input, weight, bias, stride, padding, output_padding, groups, dilation)
-
+    # Normalize scalar arguments so they can be promoted to 2-D tuples below.
+    stride = (stride,) if isinstance(stride, int) else tuple(stride)
+    padding = (padding,) if isinstance(padding, int) else tuple(padding)
+    output_padding = (output_padding,) if isinstance(output_padding, int) else tuple(output_padding)
+    dilation = (dilation,) if isinstance(dilation, int) else tuple(dilation)
+
+    # 1. Add a height dimension to the input: (batch, in_channels, 1, L_in).
+    x_2d = input.unsqueeze(2)
+
+    # 2. Add a height dimension to the kernel: (in_channels, out_channels, 1, kernel_size).
+    weight_2d = weight.unsqueeze(2)
+
+    # 3. Run the 2-D transposed convolution.
+    output_2d = conv_transpose2d(
+        x_2d,
+        weight_2d,
+        bias,
+        stride=(1,) + stride,
+        padding=(0,) + padding,
+        output_padding=(0,) + output_padding,
+        groups=groups,
+        dilation=(1,) + dilation
+    )  # Output shape: (batch, out_channels, 1, L_out).
+
+    # 4. Drop the height dimension to recover the 1-D result.
+    return output_2d.squeeze(2)

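The 1-D transposed convolution above is emulated by lifting everything to 2-D: a dummy height axis of size 1 is added to the input and the kernel, `conv_transpose2d` does the actual work, and the axis is squeezed away afterwards. A shape-level sketch (the `core.randn` tensors are illustrative assumptions, not part of the commit):

    x = core.randn(8, 16, 50)      # (batch, in_channels, L_in)
    w = core.randn(16, 32, 5)      # (in_channels, out_channels, kernel_size)
    y = conv_transpose1d(x, w, stride=2)
    # internally (8, 16, 1, 50) -> conv_transpose2d -> (8, 32, 1, 103) -> squeeze(2) -> (8, 32, 103)
    # L_out = (L_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1 = 103
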
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    return mint.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation)

@@ -1221,7 +1405,9 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
    return ops.fold(input, output_size, kernel_size, dilation, padding, stride)

def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
-    ctc_loss_op = _get_cache_prim(nn_ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
+    ctc_loss_op = _get_cache_prim(ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
+    if targets.ndim == 1:
+        targets = targets.unsqueeze(-1)
    loss, _ = ctc_loss_op(log_probs, targets, input_lengths, target_lengths)
    if zero_infinity:
        loss = ops.where(ops.isinf(loss), 0., loss)