From 9df51694db4e86c1852b0f66eb4dc98adddd5c57 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sat, 20 Aug 2022 13:10:52 +0200 Subject: [PATCH 1/3] feat: Added poly_loss --- torchvision/ops/__init__.py | 2 ++ torchvision/ops/poly_loss.py | 66 ++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 torchvision/ops/poly_loss.py diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index cd711578a6c..f8d7716f242 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -20,6 +20,7 @@ from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation +from .poly_loss import poly_loss from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool @@ -68,4 +69,5 @@ "DropBlock2d", "drop_block3d", "DropBlock3d", + "poly_loss", ] diff --git a/torchvision/ops/poly_loss.py b/torchvision/ops/poly_loss.py new file mode 100644 index 00000000000..597b0cbdaae --- /dev/null +++ b/torchvision/ops/poly_loss.py @@ -0,0 +1,66 @@ +from typing import Optional + +import torch +from torch import Tensor + +from ..utils import _log_api_usage_once + + +def poly_loss( + x: Tensor, + target: Tensor, + eps: float = 2.0, + weight: Optional[Tensor] = None, + ignore_index: int = -100, + reduction: str = "mean", +) -> Tensor: + """Implements the Poly1 loss from `"PolyLoss: A Polynomial Expansion Perspective of Classification Loss + Functions" `_. + + Args: + x (Tensor[N, K, ...]): predicted probability + target (Tensor[N, K, ...]): target probability + eps (float, optional): epsilon 1 from the paper + weight (Tensor[K], optional): manual rescaling of each class + ignore_index (int, optional): specifies target value that is ignored and do not contribute to gradient + reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` + ``'none'``: No reduction will be applied to the output. + ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. Default: ``'none'``. + + Returns: + Tensor: loss reduced with `reduction` method + """ + # Original implementation from https://github.com/frgfm/Holocron/blob/main/holocron/nn/functional.py + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(poly_loss) + # log(P[class]) = log_softmax(score)[class] + logpt = F.log_softmax(x, dim=1) + + # Compute pt and logpt only for target classes (the remaining will have a 0 coefficient) + logpt = logpt.transpose(1, 0).flatten(1).gather(0, target.view(1, -1)).squeeze() + # Ignore index (set loss contribution to 0) + valid_idxs = torch.ones(target.view(-1).shape[0], dtype=torch.bool, device=x.device) + if ignore_index >= 0 and ignore_index < x.shape[1]: + valid_idxs[target.view(-1) == ignore_index] = False + + # Get P(class) + loss = -1 * logpt + eps * (1 - logpt.exp()) + + # Weight + if weight is not None: + # Tensor type + if weight.type() != x.data.type(): + weight = weight.type_as(x.data) + logpt = weight.gather(0, target.data.view(-1)) * logpt + + # Loss reduction + if reduction == "sum": + loss = loss[valid_idxs].sum() + elif reduction == "mean": + loss = loss[valid_idxs].mean() + else: + # if no reduction, reshape tensor like target + loss = loss.view(*target.shape) + + return loss From e20926eeaad02cf24f769d14347b22b41900ba97 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Sat, 20 Aug 2022 13:11:02 +0200 Subject: [PATCH 2/3] docs: Added entry in the documentation --- docs/source/ops.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index e35981854e5..bfe2c85ca6c 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -37,6 +37,7 @@ Operators roi_pool sigmoid_focal_loss stochastic_depth + poly_loss .. autosummary:: :toctree: generated/ From 0dd221867542b591baa44629c9df12e7aea73d90 Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:33:09 +0200 Subject: [PATCH 3/3] feat: Added reduction value check --- torchvision/ops/poly_loss.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/ops/poly_loss.py b/torchvision/ops/poly_loss.py index 597b0cbdaae..4e446163c6a 100644 --- a/torchvision/ops/poly_loss.py +++ b/torchvision/ops/poly_loss.py @@ -59,8 +59,10 @@ def poly_loss( loss = loss[valid_idxs].sum() elif reduction == "mean": loss = loss[valid_idxs].mean() - else: + elif reduction == "none": # if no reduction, reshape tensor like target loss = loss.view(*target.shape) + else: + raise ValueError(f"invalid value for arg 'reduction': {reduction}") return loss