diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp
index f421f09ebe8..3850b2833ab 100644
--- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp
@@ -231,8 +231,8 @@ at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
   auto height = input.size(2);
   auto width = input.size(3);
 
-  at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
-
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
+
   auto output_size = num_rois * pooled_height * pooled_width * channels;
 
   if (output.numel() == 0)
diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp
index 5500a63f4ca..8ae35930533 100644
--- a/torchvision/csrc/cpu/ROIPool_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp
@@ -16,8 +16,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(const at::Tensor &input,
   int input_height = input.size(2);
   int input_width = input.size(3);
 
-  at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
-  at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_();
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
+  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
 
   // define accessors for indexing
   auto input_a = input.accessor();
diff --git a/torchvision/csrc/cpu/nms_cpu.cpp b/torchvision/csrc/cpu/nms_cpu.cpp
index aa4b9b53256..fa543b5f00c 100644
--- a/torchvision/csrc/cpu/nms_cpu.cpp
+++ b/torchvision/csrc/cpu/nms_cpu.cpp
@@ -1,6 +1,5 @@
 #include "cpu/vision.h"
 
-
 template <typename scalar_t>
 at::Tensor nms_cpu_kernel(const at::Tensor& dets,
                           const at::Tensor& scores,
@@ -10,7 +9,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets,
   AT_ASSERTM(dets.type() == scores.type(), "dets should have the same type as scores");
 
   if (dets.numel() == 0)
-    return torch::CPU(at::kLong).tensor();
+    return at::empty({0}, at::device(at::kCPU).dtype(at::kLong));
 
   auto x1_t = dets.select(1, 0).contiguous();
   auto y1_t = dets.select(1, 1).contiguous();
@@ -22,7 +21,7 @@ at::Tensor nms_cpu_kernel(const at::Tensor& dets,
   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
 
   auto ndets = dets.size(0);
-  at::Tensor suppressed_t = at::zeros(torch::CPU(at::kByte), {ndets});
+  at::Tensor suppressed_t = at::zeros({ndets}, at::device(at::kCPU).dtype(at::kByte));
 
   auto suppressed = suppressed_t.data<uint8_t>();
   auto order = order_t.data<int64_t>();
@@ -66,7 +65,7 @@ at::Tensor nms_cpu(const at::Tensor& dets,
                    const at::Tensor& scores,
                    const float threshold) {
 
-  auto result = dets.type().tensor();
+  auto result = at::empty({0}, dets.type());
 
   AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
     result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu
index e8408c4ee06..9cc5ae28934 100644
--- a/torchvision/csrc/cuda/ROIAlign_cuda.cu
+++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu
@@ -267,7 +267,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
   auto height = input.size(2);
   auto width = input.size(3);
 
-  at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
 
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -313,7 +313,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
   AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
 
   auto num_rois = rois.size(0);
-  at::Tensor grad_input = grad.type().tensor({batch_size, channels, height, width}).zero_();
+  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu
index 29d9c9c9319..2ba8dc33e25 100644
--- a/torchvision/csrc/cuda/ROIPool_cuda.cu
+++ b/torchvision/csrc/cuda/ROIPool_cuda.cu
@@ -116,8 +116,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
   auto height = input.size(2);
   auto width = input.size(3);
 
-  at::Tensor output = input.type().tensor({num_rois, channels, pooled_height, pooled_width});
-  at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({num_rois, channels, pooled_height, pooled_width}).zero_();
+  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type());
+  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.type().toScalarType(at::kInt));
 
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
diff --git a/torchvision/layers/roi_pool.py b/torchvision/layers/roi_pool.py
index f232d0cc160..b6d288e95e2 100644
--- a/torchvision/layers/roi_pool.py
+++ b/torchvision/layers/roi_pool.py
@@ -11,14 +11,14 @@ class _ROIPool(Function):
     @staticmethod
-    def forward(ctx, input, roi, output_size, spatial_scale):
+    def forward(ctx, input, rois, output_size, spatial_scale):
         ctx.output_size = _pair(output_size)
         ctx.spatial_scale = spatial_scale
         ctx.input_shape = input.size()
         output, argmax = _C.roi_pool_forward(
-            input, roi, spatial_scale,
+            input, rois, spatial_scale,
             output_size[0], output_size[1])
-        ctx.save_for_backward(roi, argmax)
+        ctx.save_for_backward(rois, argmax)
         return output
 
     @staticmethod
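
For context on the pattern applied throughout this patch: the deprecated at::Type-based factories (input.type().tensor({...}), torch::CPU(at::kLong).tensor(), at::zeros(type, sizes)) are replaced by the sizes-first factory functions taking at::TensorOptions (at::zeros(sizes, options), at::empty(sizes, options)). The sketch below is not part of the patch; it is a minimal standalone illustration of the old and new creation styles, with made-up shapes and variable names, assuming a program compiled against ATen/libtorch.

// Minimal sketch of the API migration performed by this patch (illustrative only).
#include <ATen/ATen.h>

int main() {
  // Stand-in for a layer input; shape and names are hypothetical.
  at::Tensor input = at::rand({2, 3, 4, 4});

  // Old, deprecated style: create tensors through the at::Type object.
  //   at::Tensor output = input.type().tensor({2, 3, 4, 4});
  //   at::Tensor argmax = input.type().toScalarType(at::kInt).tensor({2, 3, 4, 4}).zero_();
  //   at::Tensor keep   = torch::CPU(at::kLong).tensor();

  // New style: sizes first, then TensorOptions describing dtype/device.
  at::Tensor output = at::zeros({2, 3, 4, 4}, input.options());
  at::Tensor argmax = at::zeros({2, 3, 4, 4}, input.options().dtype(at::kInt));
  at::Tensor keep   = at::empty({0}, at::device(at::kCPU).dtype(at::kLong));

  return 0;
}

Note that most hunks pass input.type() as the second argument while the CUDA ROIAlign hunks pass input.options(); both compiled against ATen at the time of this patch, but options() is the spelling that remained supported in later releases.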