Skip to content

Commit 724f735

Browse files
authored
fix transformers w class ut and diffusers omni (#2092)
1 parent df8e766 commit 724f735

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

mindnlp/core/_tensor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,27 @@ def __contains__(self, item):
711711
Tensor.var = ops.var
712712
StubTensor.var = ops.var
713713

714+
Tensor.logsumexp = ops.logsumexp
715+
StubTensor.logsumexp = ops.logsumexp
716+
717+
def __bool__(self):
718+
if self.ndim > 0:
719+
return True
720+
return bool(self._item())
721+
722+
Tensor.__bool__ = __bool__
723+
StubTensor.__bool__ = __bool__
724+
725+
def __iter__(self):
726+
if self.ndim == 0:
727+
yield self
728+
else:
729+
for i in range(len(self)):
730+
yield self[i]
731+
732+
Tensor.__iter__ = __iter__
733+
StubTensor.__iter__ = __iter__
734+
714735
def _rebuild_from_type_v2(func, new_type, args, state):
715736
ret = func(*args)
716737
return ret

mindnlp/core/ops/other.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,9 @@ def cartesian_prod(*tensors):
10441044
def detach(input):
10451045
return ops.stop_gradient(input)
10461046

1047+
def cosine_similarity(*args, **kwargs):
1048+
return core.nn.functional.cosine_similarity(*args, **kwargs)
1049+
10471050
__all__ = [
10481051
"bincount",
10491052
"broadcast_shapes",
@@ -1081,5 +1084,6 @@ def detach(input):
10811084
"histc",
10821085
"view_as_complex",
10831086
"view_as_real",
1084-
"detach"
1087+
"detach",
1088+
"cosine_similarity"
10851089
]

0 commit comments

Comments
 (0)