
Commit a91fe72

t-vi authored and fmassa committed
Make custom ops differentiable (#1314)
* Make custom ops differentiable and replace autograd.Function. Use ops unconditionally. We may consider removing the extension functions in a follow-up. The code path is tested by the existing tests for differentiability.
* add scripted gradcheck tests and use intlist
* fix implicit tuple conversion for gcc-5
* fix merge
1 parent cabca39 commit a91fe72
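Taken together, the diffs below move the autograd logic from Python autograd.Function subclasses into the C++ operators registered under torch.ops.torchvision, so the ops are differentiable wherever they are called (eager, traced, or scripted). A minimal sketch of the resulting behaviour, not part of the commit, assuming torchvision was built with its C++ extension and using illustrative shapes and ROI values:

    import torch
    from torchvision import ops

    x = torch.rand(1, 1, 10, 10, requires_grad=True)
    # each ROI row is (batch_index, x1, y1, x2, y2)
    rois = torch.tensor([[0., 0., 0., 9., 9.]])

    out = ops.roi_align(x, rois, (5, 5), spatial_scale=1.0, sampling_ratio=1)
    out.sum().backward()     # backward now runs through the op registered in custom_ops.cpp
    print(x.grad.shape)      # torch.Size([1, 1, 10, 10])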

File tree

4 files changed: +160 -74 lines changed

  test/test_ops.py
  torchvision/csrc/custom_ops/custom_ops.cpp
  torchvision/ops/roi_align.py
  torchvision/ops/roi_pool.py


test/test_ops.py

Lines changed: 24 additions & 0 deletions
@@ -188,6 +188,12 @@ def func(input):
         assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CPU'
         assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for roi_pool CPU'
 
+        @torch.jit.script
+        def script_func(input, rois):
+            return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
+
+        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool'
+
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
     def test_roi_pool_basic_cuda(self):
         device = torch.device('cuda')
@@ -274,6 +280,12 @@ def func(input):
         assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CUDA'
         assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for roi_pool CUDA'
 
+        @torch.jit.script
+        def script_func(input, rois):
+            return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
+
+        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool on CUDA'
+
 
 class RoIAlignTester(unittest.TestCase):
     @classmethod
@@ -428,6 +440,12 @@ def func(input):
         assert gradcheck(func, (x,)), 'gradcheck failed for RoIAlign CPU'
         assert gradcheck(func, (x.transpose(2, 3),)), 'gradcheck failed for RoIAlign CPU'
 
+        @torch.jit.script
+        def script_func(input, rois):
+            return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
+
+        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align'
+
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
     def test_roi_align_gradient_cuda(self):
         """
@@ -462,6 +480,12 @@ def func(input):
         assert gradcheck(func, (x,)), 'gradcheck failed for RoIAlign CUDA'
         assert gradcheck(func, (x.transpose(2, 3),)), 'gradcheck failed for RoIAlign CUDA'
 
+        @torch.jit.script
+        def script_func(input, rois):
+            return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
+
+        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA'
+
 
 class NMSTester(unittest.TestCase):
     def reference_nms(self, boxes, scores, iou_threshold):
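The scripted gradcheck calls above reuse the x and rois tensors built earlier in each test. A self-contained sketch of the same check, with illustrative shapes and an extra wrapper call so that _lazy_import has registered the ops before scripting:

    import torch
    from torch.autograd import gradcheck
    from torchvision import ops

    x = torch.rand(1, 1, 10, 10, dtype=torch.double, requires_grad=True)
    rois = torch.tensor([[0., 0., 0., 9., 9.]], dtype=torch.double)
    ops.roi_pool(x, rois, (5, 5))   # makes sure torch.ops.torchvision.* is registered

    @torch.jit.script
    def script_func(input, rois):
        return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]

    assert gradcheck(lambda t: script_func(t, rois), (x,))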

torchvision/csrc/custom_ops/custom_ops.cpp

Lines changed: 128 additions & 2 deletions
@@ -25,9 +25,135 @@ PyMODINIT_FUNC PyInit__custom_ops(void) {
 #endif
 #endif
 
+using torch::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      Variable input,
+      Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t sampling_ratio) {
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["sampling_ratio"] = sampling_ratio;
+    ctx->saved_data["input_shape"] = input.sizes();
+    ctx->save_for_backward({rois});
+    auto result = ROIAlign_forward(
+        input,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        sampling_ratio);
+    return {result};
+  }
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = ROIAlign_backward(
+        grad_output[0],
+        rois,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3],
+        ctx->saved_data["sampling_ratio"].toInt());
+    return {
+        grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
+  }
+};
+
+Tensor roi_align(
+    const Tensor& input,
+    const Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t sampling_ratio) {
+  return ROIAlignFunction::apply(
+      input,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio)[0];
+}
+
+class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      Variable input,
+      Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width) {
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["input_shape"] = input.sizes();
+    auto result = ROIPool_forward(
+        input, rois, spatial_scale, pooled_height, pooled_width);
+    auto output = std::get<0>(result);
+    auto argmax = std::get<1>(result);
+    ctx->save_for_backward({rois, argmax});
+    ctx->mark_non_differentiable({argmax});
+    return {output, argmax};
+  }
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto argmax = saved[1];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = ROIPool_backward(
+        grad_output[0],
+        rois,
+        argmax,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3]);
+    return {grad_in, Variable(), Variable(), Variable(), Variable()};
+  }
+};
+
+std::tuple<Tensor, Tensor> roi_pool(
+    const Tensor& input,
+    const Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width) {
+  auto result = ROIPoolFunction::apply(
+      input, rois, spatial_scale, pooled_height, pooled_width);
+  return std::tuple<Tensor, Tensor>(result[0], result[1]);
+}
+
 static auto registry =
     torch::RegisterOperators()
         .op("torchvision::nms", &nms)
         .op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
-            &ROIAlign_forward)
-        .op("torchvision::roi_pool", &ROIPool_forward);
+            &roi_align)
+        .op("torchvision::roi_pool", &roi_pool);

torchvision/ops/roi_align.py

Lines changed: 4 additions & 37 deletions
@@ -10,35 +10,6 @@
 from ._utils import convert_boxes_to_roi_format
 
 
-class _RoIAlignFunction(Function):
-    @staticmethod
-    def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
-        ctx.save_for_backward(roi)
-        ctx.output_size = _pair(output_size)
-        ctx.spatial_scale = spatial_scale
-        ctx.sampling_ratio = sampling_ratio
-        ctx.input_shape = input.size()
-        _C = _lazy_import()
-        output = _C.roi_align_forward(
-            input, roi, spatial_scale,
-            output_size[0], output_size[1], sampling_ratio)
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, grad_output):
-        rois, = ctx.saved_tensors
-        output_size = ctx.output_size
-        spatial_scale = ctx.spatial_scale
-        sampling_ratio = ctx.sampling_ratio
-        bs, ch, h, w = ctx.input_shape
-        _C = _lazy_import()
-        grad_input = _C.roi_align_backward(
-            grad_output, rois, spatial_scale,
-            output_size[0], output_size[1], bs, ch, h, w, sampling_ratio)
-        return grad_input, None, None, None, None
-
-
 def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
     """
     Performs Region of Interest (RoI) Align operator described in Mask R-CNN
@@ -66,14 +37,10 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
     rois = boxes
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    # TODO: Change this to support backwards, which we
-    # do not currently support when JIT tracing.
-    if torch._C._get_tracing_state():
-        _lazy_import()
-        return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
-                                               output_size[0], output_size[1],
-                                               sampling_ratio)
-    return _RoIAlignFunction.apply(input, rois, output_size, spatial_scale, sampling_ratio)
+    _lazy_import()
+    return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
+                                           output_size[0], output_size[1],
+                                           sampling_ratio)
 
 
 class RoIAlign(nn.Module):
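Dropping the tracing special case (and its TODO about unsupported backward) means eager and traced calls now take the same op path, and gradients also flow through a traced graph. A small sketch under the same assumptions as the earlier examples:

    import torch
    from torchvision import ops

    x = torch.rand(1, 1, 10, 10, requires_grad=True)
    rois = torch.tensor([[0., 0., 0., 9., 9.]])

    traced = torch.jit.trace(lambda t: ops.roi_align(t, rois, (5, 5)), (x,))
    traced(x).sum().backward()   # backward through the traced op now works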

torchvision/ops/roi_pool.py

Lines changed: 4 additions & 35 deletions
@@ -10,33 +10,6 @@
 from ._utils import convert_boxes_to_roi_format
 
 
-class _RoIPoolFunction(Function):
-    @staticmethod
-    def forward(ctx, input, rois, output_size, spatial_scale):
-        ctx.output_size = _pair(output_size)
-        ctx.spatial_scale = spatial_scale
-        ctx.input_shape = input.size()
-        _C = _lazy_import()
-        output, argmax = _C.roi_pool_forward(
-            input, rois, spatial_scale,
-            output_size[0], output_size[1])
-        ctx.save_for_backward(rois, argmax)
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, grad_output):
-        rois, argmax = ctx.saved_tensors
-        output_size = ctx.output_size
-        spatial_scale = ctx.spatial_scale
-        bs, ch, h, w = ctx.input_shape
-        _C = _lazy_import()
-        grad_input = _C.roi_pool_backward(
-            grad_output, rois, argmax, spatial_scale,
-            output_size[0], output_size[1], bs, ch, h, w)
-        return grad_input, None, None, None
-
-
 def roi_pool(input, boxes, output_size, spatial_scale=1.0):
     """
     Performs Region of Interest (RoI) Pool operator described in Fast R-CNN
@@ -59,14 +32,10 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0):
     rois = boxes
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    # TODO: Change this to support backwards, which we
-    # do not currently support when JIT tracing.
-    if torch._C._get_tracing_state():
-        _lazy_import()
-        output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
-                                                   output_size[0], output_size[1])
-        return output
-    return _RoIPoolFunction.apply(input, rois, output_size, spatial_scale)
+    _lazy_import()
+    output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
+                                               output_size[0], output_size[1])
+    return output
 
 
 class RoIPool(nn.Module):

0 commit comments
