1
+ """fft"""
2
+ from mindspore import ops
3
+ from mindspore .ops ._primitive_cache import _get_cache_prim
4
+ from ..configs import use_pyboost
5
+ from ..ops import narrow
6
+ from ..nn import functional as F
7
+
8
+ def rfft (input , n = None , dim = - 1 , norm = "backward" ):
9
+ if use_pyboost ():
10
+ return ops .rfft (input , n , dim , norm )
11
+ if input .shape [dim ] < n :
12
+ pad_inf = (0 , n - input .shape [dim ])
13
+ pad_dims = (0 , 0 ) * (input .ndim - (dim + 1 )) + pad_inf
14
+ input = F .pad (input , pad_dims )
15
+ else :
16
+ input = narrow (input , dim , 0 , n )
17
+ _rfft = _get_cache_prim (ops .FFTWithSize )(input .ndim , False , True , norm )
18
+ return _rfft (input )
19
+
20
+ def irfft (input , n = None , dim = - 1 , norm = "backward" ):
21
+ if use_pyboost ():
22
+ return ops .irfft (input , n , dim , norm )
23
+ if input .shape [dim ] < n :
24
+ pad_inf = (0 , n - input .shape [dim ])
25
+ pad_dims = (0 , 0 ) * (input .ndim - (dim + 1 )) + pad_inf
26
+ input = pad (input , pad_dims )
27
+ else :
28
+ input = narrow (input , dim , 0 , n )
29
+ _irfft = _get_cache_prim (ops .FFTWithSize )(input .ndim , True , True , norm )
30
+ return _irfft (input )
31
+
32
+ def fftn (input , s = None , dim = None , norm = None ):
33
+ return ops .fftn (input , s , dim , norm )
34
+
35
+ def fft (input , s = None , dim = - 1 , norm = None ):
36
+ return ops .fft (input , s , dim , norm )
37
+
38
+ __all__ = ['fft' , 'fftn' , 'irfft' , 'rfft' ]
0 commit comments