Skip to content

Commit 68f1b3b

Browse files
authored
fix diffusers pipeline h class ut (#2086)
1 parent 188c0b4 commit 68f1b3b

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

mindnlp/core/_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,8 @@ def __contains__(self, item):
681681
Tensor.bernoulli_ = ops.inplace_bernoulli
682682
StubTensor.bernoulli_ = ops.inplace_bernoulli
683683

684+
Tensor.scatter_reduce = ops.scatter_reduce
685+
StubTensor.scatter_reduce = ops.scatter_reduce
684686

685687
def _rebuild_from_type_v2(func, new_type, args, state):
686688
ret = func(*args)

mindnlp/core/ops/array.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mindspore import ops
66
from mindspore.ops._primitive_cache import _get_cache_prim
77
from mindspore.ops.operations._grad_ops import StridedSliceGrad
8+
from mindspore.ops.auto_generate.gen_ops_prim import inplace_scatter_src_reduce_op
89

910
from ..configs import use_pyboost, ON_ORANGE_PI
1011
from .other import broadcast_tensors, finfo
@@ -220,8 +221,15 @@ def scatter_add(input, dim, index, src):
220221
return mindspore.mint.scatter_add(input, dim, index, src)
221222
return ops.tensor_scatter_elements(input, index, src, dim, 'add')
222223

224+
scatter_reduce_dict = {
225+
'sum': 'add',
226+
'amax': 'max',
227+
'amin': 'min',
228+
'mean': 'mean'
229+
}
223230
# scatter_reduce
224-
231+
def scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
232+
return inplace_scatter_src_reduce_op(input.clone(), dim, index, src, scatter_reduce_dict[reduce])
225233

226234
# scatter_nd_update
227235
def scatter_nd_update(input, indices, update):
@@ -799,7 +807,7 @@ def strided_slice_update(input, begin, end, strides, update, begin_mask=0, end_m
799807
# select_scatter
800808
# slice_scatter
801809
'scatter_add',
802-
# scatter_reduce
810+
'scatter_reduce',
803811
'scatter_nd_update',
804812
'scatter_update',
805813
'split',

mindnlp/core/ops/creation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,18 @@ def zeros(*size, dtype=None, device=None, requires_grad=False, **kwargs):
5353
size = ((),)
5454
if isinstance(size[0], (tuple, list)):
5555
size = size[0]
56+
57+
new_size = ()
58+
for s in size:
59+
if not isinstance(s, int):
60+
s = s.item()
61+
new_size += (s,)
5662
if use_pyboost() and has_zeros:
5763
# if device == 'cpu':
5864
# return mindspore.Tensor(np.zeros(size), dtype=dtype)
59-
return mindspore.mint.zeros(size, dtype=dtype)
65+
return mindspore.mint.zeros(new_size, dtype=dtype)
6066
size = tuple(size)
61-
return _zeros(size, dtype)
67+
return _zeros(new_size, dtype)
6268

6369
# zeros_like
6470
has_zeros_like = hasattr(mindspore.mint, 'zeros_like')

0 commit comments

Comments
 (0)