@@ -283,13 +283,13 @@ def __setitem__(self, slices, value):
283
283
value = ops .finfo (self .dtype ).max
284
284
elif value == - float ('inf' ):
285
285
value = ops .finfo (self .dtype ).min
286
- # if isinstance(slices, tuple):
287
- # new_slices = ()
288
- # for s in slices:
289
- # if isinstance(s, range):
290
- # s = list(s)
291
- # new_slices += (s,)
292
- # slices = new_slices
286
+ if isinstance (slices , tuple ):
287
+ new_slices = ()
288
+ for s in slices :
289
+ if isinstance (s , range ):
290
+ s = list (s )
291
+ new_slices += (s ,)
292
+ slices = new_slices
293
293
if not isinstance (value , Tensor ):
294
294
value = tensor (value , dtype = self .dtype )
295
295
return origin_setitem (self , slices , value )
@@ -507,10 +507,40 @@ def __repr__(self):
507
507
Tensor .__repr__ = __repr__
508
508
StubTensor .__repr__ = _stub_method (__repr__ )
509
509
510
-
511
510
def detach_ (self ):
512
511
return ops .stop_gradient (self )
513
512
513
+ Tensor .detach_ = detach_
514
+ StubTensor .detach_ = detach_
515
+
516
+ def new_full (self , size , fill_value , * , dtype = None , device = None , requires_grad = False , layout = None , pin_memory = False ):
517
+ return ops .full (size , fill_value , dtype = dtype if dtype is not None else self .dtype )
518
+
519
+ Tensor .new_full = new_full
520
+ StubTensor .new_full = new_full
521
+
522
+ def new_zeros (self , * size , dtype = None , device = None , requires_grad = False , layout = None , pin_memory = False ):
523
+ return ops .zeros (* size , dtype = dtype if dtype is not None else self .dtype )
524
+
525
+ Tensor .new_zeros = new_zeros
526
+ StubTensor .new_zeros = new_zeros
527
+
528
+ Tensor .sum = ops .sum
529
+ StubTensor .sum = ops .sum
530
+
531
+ def new_tensor (self , data , * , dtype = None , device = None , requires_grad = False , layout = None , pin_memory = False ):
532
+ return tensor (data , dtype = dtype if dtype is not None else self .dtype )
533
+
534
+ Tensor .new_tensor = new_tensor
535
+ StubTensor .new_tensor = new_tensor
536
+
537
+ Tensor .fill_diagonal_ = ops .inplace_fill_diagonal
538
+ StubTensor .fill_diagonal_ = ops .inplace_fill_diagonal
539
+
540
+ Tensor .triu_ = ops .inplace_triu
541
+ StubTensor .triu_ = ops .inplace_triu
542
+
543
+
514
544
def _rebuild_from_type_v2 (func , new_type , args , state ):
515
545
ret = func (* args )
516
546
return ret
0 commit comments