@@ -48,35 +48,34 @@ def __isinstancecheck__(self, instance):
48
48
49
49
class IntTensor (Tensor , metaclass = TypedTensorMeta ):
50
50
dtype = _dtype .int
51
- def __init__ (self , data , device = None ):
52
- super ().__init__ (data , dtype = _dtype .int )
51
+ def __init__ (self , * args , ** kwargs ):
52
+ super ().__init__ (* args , dtype = _dtype .int , ** kwargs )
53
53
54
54
class LongTensor (Tensor , metaclass = TypedTensorMeta ):
55
55
dtype = _dtype .long
56
- def __init__ (self , data , device = None ):
57
- super ().__init__ (data , dtype = _dtype .long )
56
+ def __init__ (self , * args , ** kwargs ):
57
+ super ().__init__ (* args , dtype = _dtype .long , ** kwargs )
58
58
59
59
class FloatTensor (Tensor , metaclass = TypedTensorMeta ):
60
60
dtype = _dtype .float32
61
- def __init__ (self , data , device = None ):
62
- super ().__init__ (data , dtype = _dtype .float32 )
63
-
61
+ def __init__ (self , * args , ** kwargs ):
62
+ super ().__init__ (* args , dtype = _dtype .float32 , ** kwargs )
64
63
65
64
class HalfTensor (Tensor , metaclass = TypedTensorMeta ):
66
65
dtype = _dtype .float16
67
- def __init__ (self , data , device = None ):
68
- super ().__init__ (data , dtype = _dtype .float16 )
66
+ def __init__ (self , * args , ** kwargs ):
67
+ super ().__init__ (* args , dtype = _dtype .float16 , ** kwargs )
69
68
70
69
class BFloat16Tensor (Tensor , metaclass = TypedTensorMeta ):
71
70
dtype = _dtype .float16
72
- def __init__ (self , data , device = None ):
73
- super ().__init__ (data , dtype = _dtype .bfloat16 )
74
-
71
+ def __init__ (self , * args , ** kwargs ):
72
+ super ().__init__ (* args , dtype = _dtype .bfloat16 , ** kwargs )
75
73
76
74
class BoolTensor (Tensor , metaclass = TypedTensorMeta ):
77
75
dtype = _dtype .bool
78
- def __init__ (self , data , device = None ):
79
- super ().__init__ (data , dtype = _dtype .bool )
76
+ def __init__ (self , * args , ** kwargs ):
77
+ super ().__init__ (* args , dtype = _dtype .bool , ** kwargs )
78
+
80
79
81
80
def tensor (data , * , dtype = None , device = None , requires_grad = False ):
82
81
if isinstance (data , Tensor ):
0 commit comments