@@ -274,12 +274,12 @@ template <typename OutTensorType, typename InTensorType> class matxCUDAFFTPlan_t
274274
275275 static inline constexpr cudaDataType GetInputType ()
276276 {
277- return GetIOType <T2>();
277+ return MatXTypeToCudaType <T2>();
278278 }
279279
280280 static inline constexpr cudaDataType GetOutputType ()
281281 {
282- return GetIOType <T1>();
282+ return MatXTypeToCudaType <T1>();
283283 }
284284
285285 static inline constexpr cudaDataType GetExecType ()
@@ -300,36 +300,6 @@ template <typename OutTensorType, typename InTensorType> class matxCUDAFFTPlan_t
300300 return CUDA_C_64F;
301301 }
302302
303- template <typename T> static inline constexpr cudaDataType GetIOType ()
304- {
305- if constexpr (std::is_same_v<T, matxFp16Complex>) {
306- return CUDA_C_16F;
307- }
308- else if constexpr (std::is_same_v<T, matxBf16Complex>) {
309- return CUDA_C_16BF;
310- }
311- if constexpr (std::is_same_v<T, matxFp16>) {
312- return CUDA_R_16F;
313- }
314- else if constexpr (std::is_same_v<T, matxBf16>) {
315- return CUDA_R_16BF;
316- }
317- if constexpr (std::is_same_v<T, cuda::std::complex <float >>) {
318- return CUDA_C_32F;
319- }
320- else if constexpr (std::is_same_v<T, cuda::std::complex <double >>) {
321- return CUDA_C_64F;
322- }
323- if constexpr (std::is_same_v<T, float >) {
324- return CUDA_R_32F;
325- }
326- else if constexpr (std::is_same_v<T, double >) {
327- return CUDA_R_64F;
328- }
329-
330- return CUDA_C_32F;
331- }
332-
333303 virtual ~matxCUDAFFTPlan_t () {
334304 if (this ->workspace_ != nullptr ) {
335305 // Pass the default stream until we allow user-deletable caches
0 commit comments