Add MPS kernels #7643
Changes from all commits
First file: the Python test suite for the ops.
```diff
@@ -10,7 +10,7 @@
 import torch
 import torch.fx
 import torch.nn.functional as F
-from common_utils import assert_equal, cpu_and_cuda, needs_cuda
+from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
 from torch.autograd import gradcheck
```
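For reference, a minimal sketch of what the two device-list helpers imported here could look like. The real definitions live in the test suite's `common_utils` module and may differ in detail; the bodies below are assumptions, not the PR's code.

```python
# Hypothetical sketch of the device-list helpers from common_utils.
import pytest

def cpu_and_cuda():
    # "cpu" always runs; "cuda" is skipped on machines without a CUDA GPU
    return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))

def cpu_and_cuda_and_mps():
    # extend the same list with an "mps" entry gated on Apple MPS support
    return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)
```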
```diff
@@ -96,12 +96,33 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:

 class RoIOpTester(ABC):
     dtype = torch.float64
+    mps_dtype = torch.float32
+    mps_backward_atol = 2e-2

-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
-        x_dtype = self.dtype if x_dtype is None else x_dtype
-        rois_dtype = self.dtype if rois_dtype is None else rois_dtype
+    @pytest.mark.parametrize(
+        "x_dtype",
+        (
+            torch.float16,
+            torch.float32,
+            torch.float64,
+        ),
+        ids=str,
+    )
+    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
+        if device == "mps" and x_dtype is torch.float64:
+            pytest.skip("MPS does not support float64")
+
+        rois_dtype = x_dtype if rois_dtype is None else rois_dtype
+
+        tol = 1e-5
+        if x_dtype is torch.half:
+            if device == "mps":
+                tol = 5e-3
+            else:
+                tol = 4e-3
+
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
         n_channels = 2 * (pool_size**2)
```
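The `pytest.skip` for float64 on MPS reflects a hard backend limitation: MPS tensors cannot be created in double precision at all. A quick illustration; the exact exception type and message are from memory and may vary across PyTorch versions:

```python
import torch

if torch.backends.mps.is_available():
    try:
        torch.zeros(1, dtype=torch.float64, device="mps")
    except Exception as e:
        # PyTorch raises here with a message along the lines of:
        # "Cannot convert a MPS Tensor to float64 dtype as the MPS framework
        #  doesn't support float64. Please use float32 instead."
        print(type(e).__name__, e)
```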
```diff
@@ -120,10 +141,9 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ
         # the following should be true whether we're running an autocast test or not.
         assert y.dtype == x.dtype
         gt_y = self.expected_fn(
-            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs
+            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
         )

-        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
         torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

     @pytest.mark.parametrize("device", cpu_and_cuda())
```
```diff
@@ -155,16 +175,19 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
         torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic=False):
+        atol = self.mps_backward_atol if device == "mps" else 1e-05
+        dtype = self.mps_dtype if device == "mps" else self.dtype
+
         torch.random.manual_seed(seed)
         pool_size = 2
-        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
         if not contiguous:
             x = x.permute(0, 1, 3, 2)
         rois = torch.tensor(
-            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device  # format is (xyxy)
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
         )

         def func(z):
```
```diff
@@ -173,9 +196,25 @@ def func(z):
         script_func = self.get_script_fn(rois, pool_size)

         with DeterministicGuard(deterministic):
-            gradcheck(func, (x,))
+            gradcheck(func, (x,), atol=atol)

-        gradcheck(script_func, (x,))
+        gradcheck(script_func, (x,), atol=atol)
+
+    @needs_mps
+    def test_mps_error_inputs(self):
+        pool_size = 2
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
+        rois = torch.tensor(
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
+        )
+
+        def func(z):
+            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
+
+        with pytest.raises(
+            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
+        ):
+            gradcheck(func, (x,))

     @needs_cuda
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
```
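The `match` pattern in `test_mps_error_inputs` is shared by all `RoIOpTester` subclasses, so it is written as a regex that accepts the error message of any of the four ops. A quick check of what it matches (illustrative, not part of the PR):

```python
import re

pattern = r"MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
for op in ("roi_align", "roi_pool", "ps_roi_align", "ps_roi_pool"):
    msg = f"MPS does not support {op} backward with float16 inputs."
    # pytest.raises(match=...) also uses re.search under the hood
    assert re.search(pattern, msg)
```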
```diff
@@ -271,6 +310,8 @@ def test_jit_boxes_list(self):


 class TestPSRoIPool(RoIOpTester):
+    mps_backward_atol = 5e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
```

Review thread on the `mps_backward_atol = 5e-2` line:

> @albanD, any thought regarding this? For ref we typically use …

> The gradcheck is a bit tricky here as we usually only run it in fp64 precision to get accurate results.
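To make the reviewers' point concrete: `gradcheck` compares analytical gradients against finite differences, and in float32 the finite-difference estimate alone carries error on the order of 1e-3 to 1e-2, which is why the MPS path (float32 only) needs an `atol` around 1e-2 while the default float64 path passes with `gradcheck`'s defaults. A hedged sketch with a toy function, not the PR's code:

```python
import torch
from torch.autograd import gradcheck

def f(t):
    return (t.sin() * t).sum()

# float64: finite differences are accurate, default tolerances pass
x64 = torch.rand(8, dtype=torch.float64, requires_grad=True)
assert gradcheck(f, (x64,))

# float32: the numerical gradient is much noisier, so a larger step and a
# looser atol are needed (gradcheck also warns that non-double inputs make
# the check unreliable)
x32 = torch.rand(8, dtype=torch.float32, requires_grad=True)
assert gradcheck(f, (x32,), eps=1e-3, atol=2e-2)
```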
```diff
@@ -352,6 +393,8 @@ def bilinear_interpolate(data, y, x, snap_border=False):


 class TestRoIAlign(RoIOpTester):
+    mps_backward_atol = 6e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
         return ops.RoIAlign(
             (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
```
```diff
@@ -418,10 +461,11 @@ def test_boxes_shape(self):
         self._helper_boxes_shape(ops.roi_align)

     @pytest.mark.parametrize("aligned", (True, False))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
+    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
-    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
+    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
         super().test_forward(
```
```diff
@@ -450,7 +494,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
         )

     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic):
```
```diff
@@ -537,6 +581,8 @@ def test_jit_boxes_list(self):


 class TestPSRoIAlign(RoIOpTester):
+    mps_backward_atol = 5e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
```
```diff
@@ -705,40 +751,53 @@ def test_qnms(self, iou, scale, zero_point):

         torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

-    @needs_cuda
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
-    def test_nms_cuda(self, iou, dtype=torch.float64):
+    def test_nms_gpu(self, iou, device, dtype=torch.float64):
+        dtype = torch.float32 if device == "mps" else dtype
         tol = 1e-3 if dtype is torch.half else 1e-5
         err_msg = "NMS incompatible between CPU and CUDA for IoU={}"

         boxes, scores = self._create_tensors_with_iou(1000, iou)
         r_cpu = ops.nms(boxes, scores, iou)
-        r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
+        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)

-        is_eq = torch.allclose(r_cpu, r_cuda.cpu())
+        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
         if not is_eq:
             # if the indices are not the same, ensure that it's because the scores
             # are duplicate
-            is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
+            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)

     @needs_cuda
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, iou, dtype):
         with torch.cuda.amp.autocast():
-            self.test_nms_cuda(iou=iou, dtype=dtype)
+            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

-    @needs_cuda
-    def test_nms_cuda_float16(self):
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
+    def test_nms_float16(self, device):
         boxes = torch.tensor(
             [
                 [285.3538, 185.5758, 1193.5110, 851.4551],
                 [285.1472, 188.7374, 1192.4984, 851.0669],
                 [279.2440, 197.9812, 1189.4746, 849.2019],
             ]
-        ).cuda()
-        scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
+        ).to(device)
+        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)

         iou_thres = 0.2
         keep32 = ops.nms(boxes, scores, iou_thres)
```
Second file (new): a small C++ helper header for the MPS kernels.

```diff
@@ -0,0 +1,6 @@
+constexpr int threadsPerBlock = 512;
+
+template <typename T>
+constexpr inline T ceil_div(T n, T m) {
+  return (n + m - 1) / m;
+}
```
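`ceil_div(n, m)` is the standard integer ceiling division used to size kernel launch grids: with `threadsPerBlock = 512`, a kernel over `n` elements needs `ceil_div(n, 512)` thread blocks so the final partial block is not dropped. The identity `(n + m - 1) / m == ceil(n / m)` holds for non-negative `n` and positive integer `m`; a quick sanity check of the same arithmetic (illustrative, not part of the PR):

```python
import math

def ceil_div(n, m):
    # mirrors the C++ template: (n + m - 1) / m with integer division
    return (n + m - 1) // m

threads_per_block = 512
assert ceil_div(0, threads_per_block) == 0
assert ceil_div(512, threads_per_block) == 1
assert ceil_div(513, threads_per_block) == 2  # one extra block for the tail
for n in range(5000):
    assert ceil_div(n, threads_per_block) == math.ceil(n / threads_per_block)
```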