@@ -197,8 +197,8 @@ def rms_norm(input, normalized_shape, weight, eps=None):
     if eps is None:
         eps = core.finfo(input.dtype).eps
     if weight is None:
-        weight = core.ones(normalized_shape)
-    return ops.rms_norm(input, weight, eps)[0]
+        weight = core.ones(normalized_shape, dtype=input.dtype, device=input.device)
+    return execute('rms_norm', input, weight, eps)
 
 def fast_gelu(x):
     return ops.fast_gelu(x)
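With weight=None, the fallback ones-weight is now created on the input's dtype and device before the call is routed through execute. A minimal sketch of the updated behavior; the module alias F and a torch-like core.randn are assumptions, not part of this diff:

    # Hypothetical usage sketch; `F` stands for this functional module and
    # `core.randn` is assumed to mirror torch.randn.
    x = core.randn(2, 8, 16, dtype=core.float16)
    # weight=None now materializes core.ones((16,), dtype=x.dtype, device=x.device),
    # so the dispatched 'rms_norm' op sees operands with matching dtype/device.
    y = F.rms_norm(x, normalized_shape=(16,), weight=None)
    assert y.dtype == x.dtype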
@@ -760,7 +760,6 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
 def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     if isinstance(padding, str):
         return execute('conv3d_padding', input, weight, bias, stride, padding, dilation, groups)
-    print(input.device, weight.device)
     return execute('conv3d_ext', input, weight, bias, stride, padding, dilation, groups)
 
 pad_mode = 'pad'
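conv3d keeps two dispatch paths: string padding routes to the padding-aware op, numeric padding to conv3d_ext; the deleted print was a stray debug line between them. A hedged illustration of the dispatch, where F, core.randn, and 'same' as an accepted padding string are all assumptions:

    # Hypothetical usage sketch.
    x = core.randn(1, 3, 8, 32, 32)      # (N, C_in, D, H, W)
    w = core.randn(6, 3, 3, 3, 3)        # (C_out, C_in/groups, kD, kH, kW)
    y1 = F.conv3d(x, w, padding='same')  # str -> execute('conv3d_padding', ...)
    y2 = F.conv3d(x, w, padding=1)       # int -> execute('conv3d_ext', ...)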
@@ -1577,28 +1576,13 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
     return ops.fold(input, output_size, kernel_size, dilation, padding, stride)
 
 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
-    ctc_loss_op = _get_cache_prim(ops.CTCLossV2)(blank=blank, reduction="none", zero_infinity=zero_infinity)
-    if targets.ndim == 1:
-        targets = targets.unsqueeze(-1)
-    loss, _ = ctc_loss_op(log_probs, targets, input_lengths, target_lengths)
-    if zero_infinity:
-        loss = ops.where(ops.isinf(loss), 0., loss)
-    if reduction == 'sum':
-        loss = loss.sum()
-    if reduction == 'mean':
-        input_type = loss.dtype
-        target_length_t = target_lengths.clip(1., None)
-        loss = loss.astype("float32")
-        loss = loss / target_length_t
-        loss = loss.mean()
-        loss = loss.astype(input_type)
-    return loss
+    return execute('ctc_loss', log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
 
 def one_hot(tensor, num_classes=-1):
     return execute('one_hot_ext', tensor, num_classes)
 
 def pixel_shuffle(input, upscale_factor):
-    return ops.pixel_shuffle(input, upscale_factor)
+    return execute('pixel_shuffle', input, upscale_factor)
 
 def pixel_unshuffle(input, downscale_factor):
     return ops.pixel_unshuffle(input, downscale_factor)
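ctc_loss now delegates reduction and zero_infinity handling to a single fused 'ctc_loss' op instead of composing CTCLossV2 with manual infinity masking and a mean over clipped target lengths. A usage sketch under the usual CTC conventions (log_probs shaped (T, N, C), class 0 as the blank); the alias F and the core constructors are assumptions:

    # Hypothetical usage sketch; `core` is assumed to mirror torch's tensor API.
    T, N, C, S = 50, 4, 20, 12                       # time steps, batch, classes, max target length
    log_probs = core.randn(T, N, C).log_softmax(-1)  # log-probabilities over classes
    targets = core.randint(1, C, (N, S))             # labels drawn from [1, C); 0 is the blank
    input_lengths = core.full((N,), T, dtype=core.int64)
    target_lengths = core.randint(1, S + 1, (N,), dtype=core.int64)
    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      blank=0, reduction='mean', zero_infinity=True)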