Skip to content

Commit 6180bb2

Browse files
authored
fix transformers u class ut (#2088)
1 parent 738b626 commit 6180bb2

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

mindnlp/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
contiguous_format = None
3636
preserve_format = None
3737
legacy_contiguous_format = None
38+
channels_last_3d = None
3839

3940
inf = float("inf")
4041
nan = float("nan")

mindnlp/core/_tensor.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,35 +48,34 @@ def __isinstancecheck__(self, instance):
4848

4949
class IntTensor(Tensor, metaclass=TypedTensorMeta):
5050
dtype = _dtype.int
51-
def __init__(self, data, device=None):
52-
super().__init__(data, dtype=_dtype.int)
51+
def __init__(self, *args, **kwargs):
52+
super().__init__(*args, dtype=_dtype.int, **kwargs)
5353

5454
class LongTensor(Tensor, metaclass=TypedTensorMeta):
5555
dtype = _dtype.long
56-
def __init__(self, data, device=None):
57-
super().__init__(data, dtype=_dtype.long)
56+
def __init__(self, *args, **kwargs):
57+
super().__init__(*args, dtype=_dtype.long, **kwargs)
5858

5959
class FloatTensor(Tensor, metaclass=TypedTensorMeta):
6060
dtype = _dtype.float32
61-
def __init__(self, data, device=None):
62-
super().__init__(data, dtype=_dtype.float32)
63-
61+
def __init__(self, *args, **kwargs):
62+
super().__init__(*args, dtype=_dtype.float32, **kwargs)
6463

6564
class HalfTensor(Tensor, metaclass=TypedTensorMeta):
6665
dtype = _dtype.float16
67-
def __init__(self, data, device=None):
68-
super().__init__(data, dtype=_dtype.float16)
66+
def __init__(self, *args, **kwargs):
67+
super().__init__(*args, dtype=_dtype.float16, **kwargs)
6968

7069
class BFloat16Tensor(Tensor, metaclass=TypedTensorMeta):
7170
dtype = _dtype.float16
72-
def __init__(self, data, device=None):
73-
super().__init__(data, dtype=_dtype.bfloat16)
74-
71+
def __init__(self, *args, **kwargs):
72+
super().__init__(*args, dtype=_dtype.bfloat16, **kwargs)
7573

7674
class BoolTensor(Tensor, metaclass=TypedTensorMeta):
7775
dtype = _dtype.bool
78-
def __init__(self, data, device=None):
79-
super().__init__(data, dtype=_dtype.bool)
76+
def __init__(self, *args, **kwargs):
77+
super().__init__(*args, dtype=_dtype.bool, **kwargs)
78+
8079

8180
def tensor(data, *, dtype=None, device=None, requires_grad=False):
8281
if isinstance(data, Tensor):

0 commit comments

Comments
 (0)