Skip to content

Commit 12c8b87

Browse files
Ammend iou_fn tests
1 parent e178f36 commit 12c8b87

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

test/test_ops.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,8 +1367,6 @@ def assert_empty_loss(iou_fn, dtype, device):
13671367
loss = iou_fn(box1, box2, reduction="none")
13681368
assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty"
13691369

1370-
def assert_reduction_mode(iou_fn, box1, box2, reduction):
1371-
assert iou_fn(box1, box2, reduction) == ValueError
13721370

13731371
class TestGeneralizedBoxIouLoss:
13741372
# We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
@@ -1399,7 +1397,8 @@ def test_giou_loss(self, dtype, device):
13991397
# Test reduction value
14001398
# reduction value other than ["none", "mean", "sum"] should raise a ValueError
14011399
with pytest.raises(ValueError, match="Invalid"):
1402-
assert_reduction_mode(ops.generalized_box_iou_loss, box1s, box2s, reduction="xyz")
1400+
ops.generalized_box_iou_loss(box1s, box2s, reduction="xyz")
1401+
14031402

14041403
@pytest.mark.parametrize("device", cpu_and_gpu())
14051404
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@@ -1421,7 +1420,7 @@ def test_ciou_loss(self, dtype, device):
14211420
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")
14221421

14231422
with pytest.raises(ValueError, match="Invalid"):
1424-
assert_reduction_mode(ops.complete_box_iou_loss, box1s, box2s, reduction="xyz")
1423+
ops.complete_box_iou_loss(box1s, box2s, reduction="xyz")
14251424

14261425
@pytest.mark.parametrize("device", cpu_and_gpu())
14271426
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@@ -1443,7 +1442,7 @@ def test_distance_iou_loss(self, dtype, device):
14431442
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")
14441443

14451444
with pytest.raises(ValueError, match="Invalid"):
1446-
assert_reduction_mode(ops.distance_box_iou_loss, box1s, box2s, reduction="xyz")
1445+
ops.distance_box_iou_loss(box1s, box2s, reduction="xyz")
14471446

14481447
@pytest.mark.parametrize("device", cpu_and_gpu())
14491448
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@@ -1576,7 +1575,7 @@ def test_reduction_mode(self, device, dtype, reduction="xyz"):
15761575
torch.random.manual_seed(0)
15771576
inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
15781577
with pytest.raises(ValueError, match="Invalid"):
1579-
ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction) == ValueError
1578+
ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)
15801579

15811580

15821581

0 commit comments

Comments
 (0)