
Commit 738b626

fix diffusers pipelines k class ut (#2087)
1 parent 68f1b3b commit 738b626

File tree

4 files changed: 32 additions & 9 deletions

mindnlp/core/_tensor.py

Lines changed: 3 additions & 0 deletions
@@ -684,6 +684,9 @@ def __contains__(self, item):
 Tensor.scatter_reduce = ops.scatter_reduce
 StubTensor.scatter_reduce = ops.scatter_reduce
 
+Tensor.tril_ = ops.inplace_tril
+StubTensor.tril_ = ops.inplace_tril
+
 def _rebuild_from_type_v2(func, new_type, args, state):
     ret = func(*args)
     return ret
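A minimal usage sketch of the new alias (assuming mindnlp.core exposes a torch-like ones() factory; the shapes are illustrative only):

    from mindnlp import core

    mask = core.ones(4, 4)
    mask.tril_()              # in-place lower-triangular fill, now routed to ops.inplace_tril
    mask.tril_(diagonal=-1)   # the diagonal offset is forwarded to core.tril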

mindnlp/core/nn/modules/module.py

Lines changed: 18 additions & 0 deletions
@@ -2002,6 +2002,24 @@ def bfloat16(self: T) -> T:
         """
         return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)
 
+    def to_empty(
+        self, *, device, recurse: bool = True
+    ):
+        r"""Move the parameters and buffers to the specified device without copying storage.
+
+        Args:
+            device (:class:`torch.device`): The desired device of the parameters
+                and buffers in this module.
+            recurse (bool): Whether parameters and buffers of submodules should
+                be recursively moved to the specified device.
+
+        Returns:
+            Module: self
+        """
+        return self._apply(
+            lambda t: core.empty_like(t, device=device), recurse=recurse
+        )
+
     def float(self: T) -> T:
         r"""Casts all floating point parameters and buffers to ``float`` datatype.
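A minimal usage sketch of to_empty() (nn.Linear and device=None are assumptions for illustration; device=None is assumed to keep the current device, mirroring torch.empty_like semantics):

    from mindnlp.core import nn

    layer = nn.Linear(8, 8)
    layer = layer.to_empty(device=None)   # reallocate parameter/buffer storage via core.empty_like
    # parameter values are uninitialized here and must be re-initialized before use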

mindnlp/core/ops/array.py

Lines changed: 5 additions & 8 deletions
@@ -10,6 +10,7 @@
 from ..configs import use_pyboost, ON_ORANGE_PI
 from .other import broadcast_tensors, finfo
 from ._inner import call_ms_func
+from mindnlp import core
 
 # adjoint
 
@@ -221,15 +222,11 @@ def scatter_add(input, dim, index, src):
         return mindspore.mint.scatter_add(input, dim, index, src)
     return ops.tensor_scatter_elements(input, index, src, dim, 'add')
 
-scatter_reduce_dict = {
-    'sum': 'add',
-    'amax': 'max',
-    'amin': 'min',
-    'mean': 'mean'
-}
-# scatter_reduce
 def scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
-    return inplace_scatter_src_reduce_op(input.clone(), dim, index, src, scatter_reduce_dict[reduce])
+    if reduce == 'sum':
+        return scatter_add(input, dim, index, src)
+    else:
+        raise ValueError(f'do not support reduce: {reduce}')
 
 # scatter_nd_update
 def scatter_nd_update(input, indices, update):
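A minimal sketch of the narrowed scatter_reduce (assuming the function is re-exported as mindnlp.core.ops.scatter_reduce and that core.zeros/core.ones/core.tensor factories exist):

    from mindnlp import core
    from mindnlp.core import ops

    out = core.zeros(3, 5)
    index = core.tensor([[0, 1, 2, 0, 0]])
    src = core.ones(1, 5)
    out = ops.scatter_reduce(out, 0, index, src, reduce='sum')   # delegates to scatter_add
    # any other reduce ('amax', 'amin', 'mean') now raises ValueError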

mindnlp/core/ops/inplace.py

Lines changed: 6 additions & 1 deletion
@@ -185,6 +185,10 @@ def inplace_bernoulli(self, p=0.5, *, generator=None):
     self.data = core.bernoulli(self, generator=generator, p=p)
     return self
 
+def inplace_tril(self, diagonal=0):
+    self.data = core.tril(self, diagonal)
+    return self
+
 __all__ = [
     'inplace_copy',
     'inplace_zero',
@@ -207,5 +211,6 @@ def inplace_bernoulli(self, p=0.5, *, generator=None):
     'inplace_neg',
     'inplace_exp',
     'inplace_sub',
-    'inplace_bernoulli'
+    'inplace_bernoulli',
+    'inplace_tril'
 ]
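A minimal sketch of the new helper (assuming inplace_tril is re-exported from mindnlp.core.ops; functionally it is the same code path reached by the Tensor.tril_ alias registered in _tensor.py):

    from mindnlp import core
    from mindnlp.core.ops import inplace_tril

    x = core.ones(3, 3)
    inplace_tril(x)               # zeroes the strict upper triangle in place
    inplace_tril(x, diagonal=-1)  # keeps only elements strictly below the main diagonal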
