Skip to content

Commit ac0cef0

Browse files
NicolasHugpmeierdatumbox
authored andcommitted
[fbsync] Handle invalid reduction values (#6675)
Summary: * Add ValueError * Add tests for ValueError * Add tests for ValueError * Add ValueError * Change to if/else * Ammend iou_fn tests * Move code excerpt * Format tests Reviewed By: datumbox Differential Revision: D40138724 fbshipit-source-id: 56c742a8c2ff80f2f51cba4cb3156835ed250653 Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 7516e02 commit ac0cef0

File tree

5 files changed

+56
-7
lines changed

5 files changed

+56
-7
lines changed

test/test_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,11 @@ def test_giou_loss(self, dtype, device):
13941394
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, device=device, reduction="sum")
13951395
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, device=device, reduction="mean")
13961396

1397+
# Test reduction value
1398+
# reduction value other than ["none", "mean", "sum"] should raise a ValueError
1399+
with pytest.raises(ValueError, match="Invalid"):
1400+
ops.generalized_box_iou_loss(box1s, box2s, reduction="xyz")
1401+
13971402
@pytest.mark.parametrize("device", cpu_and_gpu())
13981403
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
13991404
def test_empty_inputs(self, dtype, device):
@@ -1413,6 +1418,9 @@ def test_ciou_loss(self, dtype, device):
14131418
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
14141419
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")
14151420

1421+
with pytest.raises(ValueError, match="Invalid"):
1422+
ops.complete_box_iou_loss(box1s, box2s, reduction="xyz")
1423+
14161424
@pytest.mark.parametrize("device", cpu_and_gpu())
14171425
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
14181426
def test_empty_inputs(self, dtype, device):
@@ -1432,6 +1440,9 @@ def test_distance_iou_loss(self, dtype, device):
14321440
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
14331441
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")
14341442

1443+
with pytest.raises(ValueError, match="Invalid"):
1444+
ops.distance_box_iou_loss(box1s, box2s, reduction="xyz")
1445+
14351446
@pytest.mark.parametrize("device", cpu_and_gpu())
14361447
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
14371448
def test_empty_distance_iou_inputs(self, dtype, device):
@@ -1554,6 +1565,17 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
15541565
tol = 1e-3 if dtype is torch.half else 1e-5
15551566
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)
15561567

1568+
# Raise ValueError for anonymous reduction mode
1569+
@pytest.mark.parametrize("device", cpu_and_gpu())
1570+
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1571+
def test_reduction_mode(self, device, dtype, reduction="xyz"):
1572+
if device == "cpu" and dtype is torch.half:
1573+
pytest.skip("Currently torch.half is not fully supported on cpu")
1574+
torch.random.manual_seed(0)
1575+
inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
1576+
with pytest.raises(ValueError, match="Invalid"):
1577+
ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)
1578+
15571579

15581580
class TestMasksToBoxes:
15591581
def test_masks_box(self):

torchvision/ops/ciou_loss.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,16 @@ def complete_box_iou_loss(
6363
alpha = v / (1 - iou + v + eps)
6464

6565
loss = diou_loss + alpha * v
66-
if reduction == "mean":
66+
67+
# Check reduction option and return loss accordingly
68+
if reduction == "none":
69+
pass
70+
elif reduction == "mean":
6771
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
6872
elif reduction == "sum":
6973
loss = loss.sum()
70-
74+
else:
75+
raise ValueError(
76+
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
77+
)
7178
return loss

torchvision/ops/diou_loss.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,17 @@ def distance_box_iou_loss(
5050

5151
loss, _ = _diou_iou_loss(boxes1, boxes2, eps)
5252

53-
if reduction == "mean":
53+
# Check reduction option and return loss accordingly
54+
if reduction == "none":
55+
pass
56+
elif reduction == "mean":
5457
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
5558
elif reduction == "sum":
5659
loss = loss.sum()
60+
else:
61+
raise ValueError(
62+
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
63+
)
5764
return loss
5865

5966

torchvision/ops/focal_loss.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def sigmoid_focal_loss(
3232
Loss tensor with the reduction option applied.
3333
"""
3434
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
35+
3536
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
3637
_log_api_usage_once(sigmoid_focal_loss)
3738
p = torch.sigmoid(inputs)
@@ -43,9 +44,15 @@ def sigmoid_focal_loss(
4344
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
4445
loss = alpha_t * loss
4546

46-
if reduction == "mean":
47+
# Check reduction option and return loss accordingly
48+
if reduction == "none":
49+
pass
50+
elif reduction == "mean":
4751
loss = loss.mean()
4852
elif reduction == "sum":
4953
loss = loss.sum()
50-
54+
else:
55+
raise ValueError(
56+
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
57+
)
5158
return loss

torchvision/ops/giou_loss.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,15 @@ def generalized_box_iou_loss(
6262

6363
loss = 1 - miouk
6464

65-
if reduction == "mean":
65+
# Check reduction option and return loss accordingly
66+
if reduction == "none":
67+
pass
68+
elif reduction == "mean":
6669
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
6770
elif reduction == "sum":
6871
loss = loss.sum()
69-
72+
else:
73+
raise ValueError(
74+
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
75+
)
7076
return loss

0 commit comments

Comments
 (0)