@@ -172,6 +172,9 @@ template <typename OutTensorType, typename InTensorType> class matxFFTPlan_t {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
     FftParams_t params;
 
+    // Default to default stream, but caller will generally overwrite this
+    params.stream = 0;
+
     params.irank = i.Rank();
     params.orank = o.Rank();
 
@@ -429,9 +432,11 @@ class matxFFTPlan1D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
   *     Output view
   * @param i
   *     Input view
+  * @param stream
+  *     CUDA stream in which device memory allocations may be made
   *
   * */
-  matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
+  matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0)
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
 
@@ -468,8 +473,7 @@ matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
         &workspaceSize, this->params_.exec_type);
     MATX_ASSERT(error == CUFFT_SUCCESS, matxCufftError);
 
-    matxAlloc((void **)&this->workspace_, workspaceSize);
-    cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0);
+    matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
     cufftSetWorkArea(this->plan_, this->workspace_);
 
     error = cufftXtMakePlanMany(
@@ -531,6 +535,8 @@ virtual void inline Exec(OutTensorType &o, const InTensorType &i,
  *     Output view data type
  * @tparam T2
  *     Input view data type
+ * @param stream
+ *     CUDA stream in which device memory allocations may be made
  */
 template <typename OutTensorType, typename InTensorType = OutTensorType>
 class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
@@ -548,7 +554,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
   *     Input view
   *
   * */
-  matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i)
+  matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0)
   {
     static_assert(RANK >= 2, "2D FFTs require a rank-2 tensor or higher");
 
@@ -595,8 +601,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
         this->params_.output_type, this->params_.batch,
         &workspaceSize, this->params_.exec_type);
 
-    matxAlloc((void **)&this->workspace_, workspaceSize);
-    cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0);
+    matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
     cufftSetWorkArea(this->plan_, this->workspace_);
 
     error = cufftXtMakePlanMany(
@@ -892,7 +897,7 @@ __MATX_INLINE__ void fft_impl(OutputTensor o, const InputTensor i,
   // Get cache or new FFT plan if it doesn't exist
   auto ret = detail::cache_1d.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in};
+    auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in, stream};
     detail::cache_1d.Insert(params, static_cast<void *>(tmp));
     tmp->Forward(out, in, stream, norm);
   }
@@ -935,7 +940,7 @@ __MATX_INLINE__ void ifft_impl(OutputTensor o, const InputTensor i,
   // Get cache or new FFT plan if it doesn't exist
   auto ret = detail::cache_1d.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in};
+    auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in, stream};
     detail::cache_1d.Insert(params, static_cast<void *>(tmp));
     tmp->Inverse(out, in, stream, norm);
   }
@@ -991,7 +996,7 @@ __MATX_INLINE__ void fft2_impl(OutputTensor o, const InputTensor i,
   // Get cache or new FFT plan if it doesn't exist
   auto ret = detail::cache_2d.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in};
+    auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in, stream};
     detail::cache_2d.Insert(params, static_cast<void *>(tmp));
     tmp->Forward(out, in, stream);
   }
@@ -1047,7 +1052,7 @@ __MATX_INLINE__ void ifft2_impl(OutputTensor o, const InputTensor i,
   // Get cache or new FFT plan if it doesn't exist
   auto ret = detail::cache_2d.Lookup(params);
   if (ret == std::nullopt) {
-    auto tmp = new detail::matxFFTPlan2D_t<decltype(in), decltype(out)>{out, in};
+    auto tmp = new detail::matxFFTPlan2D_t<decltype(in), decltype(out)>{out, in, stream};
     detail::cache_2d.Insert(params, static_cast<void *>(tmp));
     tmp->Inverse(out, in, stream);
   }
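For context, the sketch below shows how a caller might exercise the new stream parameter end to end: the first fft() call for a given shape builds and caches the plan, and with this change the cuFFT workspace backing that plan is allocated asynchronously in the caller's stream (via MATX_ASYNC_DEVICE_MEMORY) rather than allocated and prefetched as managed memory. This is not part of the commit; it assumes the operator-style MatX API (make_tensor, fft(), .run(stream)) from the same version of the library, so exact spellings may differ.

// Illustrative usage sketch (not part of this commit); API names assumed
// from the operator-based MatX interface.
#include <matx.h>

int main() {
  using namespace matx;

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Batched 1D complex-to-complex FFT across the last dimension.
  auto in  = make_tensor<cuda::std::complex<float>>({16, 1024});
  auto out = make_tensor<cuda::std::complex<float>>({16, 1024});

  // First call: the plan is created and cached, and its workspace is now
  // allocated asynchronously on `stream` instead of being prefetched.
  (out = fft(in)).run(stream);

  // Subsequent calls with the same parameters reuse the cached plan.
  (out = fft(in)).run(stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}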