@@ -19,6 +19,7 @@
 from mindspore import ops as P
 
 from mindnlp import core
+from mindnlp.core.executor import execute
 from .module import Module
 from .dropout import Dropout
 from ..parameter import Parameter
@@ -29,9 +30,9 @@
 __all__ = ['LSTM', 'GRU', 'RNN']
 
 
-def _init_state(shape, dtype, is_lstm):
-    hx = ops.zeros(*shape, dtype=dtype)
-    cx = ops.zeros(*shape, dtype=dtype)
+def _init_state(shape, dtype, device, is_lstm):
+    hx = ops.zeros(*shape, dtype=dtype, device=device)
+    cx = ops.zeros(*shape, dtype=dtype, device=device)
     if is_lstm:
         return (hx, cx)
     return hx
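Review note: `_init_state` now takes the target `device` and hands it straight to `ops.zeros`, so the initial hidden state (and, for LSTM, the cell state) is allocated on the caller's device instead of the framework default. Below is a minimal usage sketch, not code from this PR; it assumes `ops` is `mindnlp.core.ops`, that `core.float32` exists alongside the `core.float16` used elsewhere in the diff, and that the helper lives at the module path shown. Sizes and the 'cpu' device string are purely illustrative.

from mindnlp import core
from mindnlp.core import ops                          # assumption: the `ops` used in this file
from mindnlp.core.nn.modules.rnn import _init_state   # assumed module path for the private helper

# Illustrative sizes: 2 layers x 1 direction, batch of 4, hidden size 8.
shape = (2 * 1, 4, 8)
device = 'cpu'                                         # hypothetical; forward() passes x.device

hx, cx = _init_state(shape, core.float32, device, True)   # LSTM: returns an (h_0, c_0) pair
h0 = _init_state(shape, core.float32, device, False)      # RNN/GRU: returns a single tensor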
@@ -285,7 +286,7 @@ def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
         w_hh = ops.cat((w_hh_i, w_hh_g, w_hh_f, w_hh_o), 0)
         weight = ops.cat((w_ih, w_hh), 1)
         if b_ih is None:
-            bias = ops.zeros(w_ih.shape[0], dtype=w_ih.dtype)
+            bias = ops.zeros(w_ih.shape[0], dtype=w_ih.dtype, device=w_ih.device)
         else:
             b_ih_i, b_ih_f, b_ih_g, b_ih_o = ops.chunk(b_ih, 4, 0)
             b_hh_i, b_hh_f, b_hh_g, b_hh_o = ops.chunk(b_hh, 4, 0)
@@ -294,7 +295,8 @@ def forward(self, x, h_0, seq_length, w_ih, w_hh, b_ih, b_hh):
                              b_ih_f + b_hh_f, \
                              b_ih_o + b_hh_o), 0)
 
-        outputs, h, c, _, _, _, _, _ = self.lstm(x.to(core.float16), \
+        outputs, h, c, _, _, _, _, _ = execute('dynamic_rnn',
+                                               x.to(core.float16), \
                                        ops.transpose(weight, 1, 0).to(core.float16), \
                                        bias.to(core.float16), None, \
                                        h_0[0].unsqueeze(0).to(core.float16), \
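Review note: the fused LSTM call now goes through the executor's name-based dispatch instead of a `self.lstm` primitive bound on the module. Only the call shape `execute('dynamic_rnn', *inputs)` is visible in this diff; the sketch below is a hypothetical toy registry that illustrates that dispatch pattern, not mindnlp's actual executor.

# Hypothetical toy registry illustrating execute(op_name, *args) dispatch.
_REGISTRY = {}

def register(name):
    def deco(fn):
        _REGISTRY[name] = fn
        return fn
    return deco

def execute(op_name, *args):
    # Look up the backend kernel by name and run it on the given inputs.
    return _REGISTRY[op_name](*args)

@register('dynamic_rnn')
def dynamic_rnn(x, w, b, seq_length, init_h, init_c):
    # A real backend would launch the fused DynamicRNN kernel here.
    raise NotImplementedError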
@@ -314,8 +316,8 @@ class _RNNBase(Module):
     '''Basic class for RNN operators'''
 
     def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True,
-                 batch_first=False, dropout=0., bidirectional=False, dtype=None):
-        factory_kwargs = {'dtype': dtype}
+                 batch_first=False, dropout=0., bidirectional=False, dtype=None, device=None):
+        factory_kwargs = {'dtype': dtype, 'device': device}
         super().__init__()
 
         if not 0 <= dropout < 1:
@@ -495,7 +497,7 @@ def forward(self, x, hx=None, seq_length=None):
         x_dtype = x.dtype
         if hx is None:
             hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), \
-                             x_dtype, self.is_lstm)
+                             x_dtype, x.device, self.is_lstm)
         if self.batch_first:
             x = ops.permute(x, (1, 0, 2))
         if self.bidirectional:
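Taken together, the constructor now accepts `device=None` and folds it into `factory_kwargs` for parameter creation, while `forward` derives the zero-state device from the input, so states follow the data instead of landing on the default device. Below is a hedged end-to-end sketch, not code from this PR; it assumes `LSTM` forwards its keyword arguments unchanged to `_RNNBase`, is re-exported as `mindnlp.core.nn.LSTM`, and returns torch-style `(output, (h_n, c_n))`. Shapes and the 'cpu' device string are illustrative.

from mindnlp import core
from mindnlp.core import nn, ops

# Hypothetical construction: dtype and device end up in factory_kwargs.
lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=2,
               batch_first=True, dtype=core.float32, device='cpu')

x = ops.zeros(4, 10, 16, dtype=core.float32)   # (batch, seq, feature) with batch_first=True
out, (h_n, c_n) = lstm(x)                      # hx is None, so _init_state uses x.device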