4
4
import mindspore
5
5
from mindspore ._c_expression import _empty_instance
6
6
from mindnlp import core
7
- from .._op_prim .cpu import legacy
7
+ from .._op_prim .gpu import legacy
8
8
9
9
try :
10
10
from mindspore ._c_expression import TensorPy as Tensor_
@@ -34,6 +34,8 @@ def fill_scalar(size, fill_value, dtype):
34
34
return legacy .cast (legacy .fill_v2 (size , mindspore .Tensor (fill_value )), dtype )
35
35
36
36
def fill_tensor (size , fill_value , dtype ):
37
+ if dtype is None :
38
+ return legacy .fill_v2 (size , mindspore .Tensor (fill_value ))
37
39
return legacy .cast (legacy .fill_v2 (size , fill_value ), dtype )
38
40
39
41
def zeros_like (input , dtype ):
@@ -123,6 +125,9 @@ def div(input, other):
123
125
return legacy .div (input , other )
124
126
125
127
def mul (input , other ):
128
+ if input .dtype == core .bool :
129
+ if isinstance (other , bool ) or (not isinstance (other , numbers .Number ) and other .dtype == core .bool ):
130
+ return bitwise_and_scalar (input , other )
126
131
return legacy .mul (input , other )
127
132
128
133
def reduce_all (input , axis , keepdims ):
@@ -253,6 +258,11 @@ def less(input, other):
253
258
return legacy .less (input , other )
254
259
255
260
def select (condition , x , y ):
261
+ if isinstance (x , numbers .Number ) or x .ndim == 0 :
262
+ x = fill_scalar (condition .shape , x , None )
263
+ if isinstance (y , numbers .Number ) or y .ndim == 0 :
264
+ y = fill_scalar (condition .shape , y , None )
265
+
256
266
return legacy .select (condition , x , y )
257
267
258
268
def round (input , decimals ):
@@ -317,16 +327,15 @@ def ones_like(input, dtype):
317
327
return legacy .ones_like (input )
318
328
319
329
def embedding (input , weight , padding_idx , max_norm , norm_type , scale_grad_by_freq ):
320
- return cast ( legacy .gather (weight , input , 0 , 0 ), weight . dtype )
330
+ return legacy .gather (weight , input , 0 , 0 )
321
331
322
332
def linspace (start , end , steps , dtype ):
323
333
start = float (start )
324
334
end = float (end )
325
335
return legacy .lin_space (mindspore .Tensor (start ), mindspore .Tensor (end ), steps )
326
336
327
337
def masked_fill (input , mask , value ):
328
- if input .dtype .is_floating_point and isinstance (value , numbers .Number ):
329
- value = float (value )
338
+ value = fill_scalar ((), value , input .dtype )
330
339
return legacy .masked_fill (input , mask , value )
331
340
332
341
def sum (input , dim , keepdim , dtype ):
@@ -388,9 +397,14 @@ def layer_norm(input, normalized_shape, weight, bias, eps=1e-5):
388
397
return legacy .layer_norm (input , weight , bias , begin_axis , begin_axis , eps )
389
398
390
399
def argmin_with_value (input , axis , keep_dims ):
400
+ if axis is None :
401
+ axis = - 1
391
402
return legacy .arg_min_with_value (input , axis , keep_dims )
392
403
393
404
def argmax_with_value (input , axis , keep_dims ):
405
+ if axis is None :
406
+ axis = - 1
407
+
394
408
return legacy .arg_max_with_value (input , axis , keep_dims )
395
409
396
410
def silu (input ):
@@ -425,9 +439,13 @@ def eye(n, m, dtype):
425
439
return legacy .eye (n , m , dtype )
426
440
427
441
def argmax (input , axis , keep_dims ):
442
+ if axis is None :
443
+ axis = - 1
428
444
return legacy .arg_max_with_value (input , axis , keep_dims )[0 ]
429
445
430
446
def argmin (input , axis , keep_dims ):
447
+ if axis is None :
448
+ axis = - 1
431
449
return legacy .arg_min_with_value (input , axis , keep_dims )[0 ]
432
450
433
451
def exp (input ):
@@ -489,18 +507,7 @@ def scatter(input, dim, index, src):
489
507
return legacy .tensor_scatter_elements (input , index , src , dim , "none" )
490
508
491
509
def batch_norm (input , weight , bias , running_mean = None , runnning_var = None , training = False , momentum = 0.1 , epsilon = 1e-5 ):
492
- input_ndim = input .ndim
493
- if input_ndim == 2 :
494
- return legacy .batch_norm (input , weight , bias , running_mean , runnning_var , training , epsilon , momentum , 'NCHW' )
495
- else :
496
- input = transpose_view (input , 1 , - 1 )
497
- input_shape = input .shape
498
- input = reshape (input , (- 1 , input .shape [- 1 ]))
499
- outs = legacy .batch_norm (input , weight , bias , running_mean , runnning_var , training , epsilon , momentum , 'NCHW' )
500
- out = reshape (outs [0 ], (* input_shape [:- 1 ], - 1 ))
501
- out = transpose_view (out , 1 , - 1 )
502
-
503
- return out , outs [1 ], outs [2 ]
510
+ return legacy .batch_norm (input , weight , bias , running_mean , runnning_var , training , epsilon , momentum , 'NCHW' )
504
511
505
512
def tanh (input ):
506
513
return legacy .tanh (input )
@@ -797,25 +804,22 @@ def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=Fa
797
804
return out
798
805
799
806
def baddbmm (input , batch1 , batch2 , alpha = 1 , beta = 1 ):
800
- return add (mul (beta , input ), mul (alpha , bmm (batch1 , batch2 )))
807
+ return add (mul (input , beta ), mul (bmm (batch1 , batch2 ), alpha ))
801
808
802
809
def softplus (input , beta = 1 , threshold = 20 ):
803
810
return legacy .softplus (input )
804
811
805
812
def gather_nd (input , indices ):
806
813
return legacy .gather_nd (input , indices )
807
814
808
- def unique_consecutive (input , return_inverse , return_counts , dim ):
809
- return legacy .unique_consecutive (input , return_inverse , return_counts , dim )
810
-
811
815
def meshgrid (input , lambd ):
812
816
return legacy .meshgrid (input , lambd )
813
817
814
818
def addcmul (input , tensor1 , tensor2 , value = 1.0 ):
815
819
return legacy .addcmul (input , tensor1 , tensor2 , mindspore .Tensor (value ))
816
820
817
821
def addmm (input , mat1 , mat2 , alpha = 1.0 , beta = 1.0 ):
818
- return add (mul (beta , input ), mul (alpha , bmm (mat1 , mat2 )))
822
+ return add (mul (input , beta ), mul (bmm (mat1 , mat2 ), alpha ))
819
823
820
824
def im2col (input , kernel_size , dilation = 1 , padding = 0 , stride = 1 ):
821
825
out = legacy .im2_col (input , kernel_size , stride , dilation , padding )
@@ -1101,6 +1105,8 @@ def bernoulli(input, generator):
1101
1105
return legacy .bernoulli (input , seed , offset )
1102
1106
1103
1107
def arange (start , end , step , dtype ):
1108
+ if dtype is not None :
1109
+ return cast (legacy .range (start , end , step , 100000 ), dtype )
1104
1110
return legacy .range (start , end , step , 100000 )
1105
1111
1106
1112
def inplace_fill_scalar (input , value ):
@@ -1121,3 +1127,13 @@ def inplace_uniform(input, from_, to_, generator_):
1121
1127
mindspore .tensor (from_ , dtype = mindspore .int32 ),
1122
1128
mindspore .tensor (to_ , dtype = mindspore .int32 ), 0 , 0 )
1123
1129
return input .assign_value (value )
1130
+
1131
+ def right_shift (input , other ):
1132
+ return legacy .right_shift (input , other )
1133
+
1134
+ def inplace_fill_tensor (input , value ):
1135
+ input .assign_value (fill_tensor (input .shape , value , None ))
1136
+ return input
1137
+
1138
+ def search_sorted (sorted_sequence , values , sorter , dtype , right ):
1139
+ return legacy .search_sorted (sorted_sequence , values , sorter , dtype , right )
0 commit comments