Skip to content

feat: Added support of Poly Loss #6457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ The following vision-specific loss functions are implemented:
distance_box_iou_loss
generalized_box_iou_loss
sigmoid_focal_loss
stochastic_depth
poly_loss


Layers
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss
from .misc import Conv2dNormActivation, Conv3dNormActivation, FrozenBatchNorm2d, MLP, Permute, SqueezeExcitation
from .poly_loss import poly_loss
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
Expand Down Expand Up @@ -70,4 +71,5 @@
"DropBlock2d",
"drop_block3d",
"DropBlock3d",
"poly_loss",
]
68 changes: 68 additions & 0 deletions torchvision/ops/poly_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from ..utils import _log_api_usage_once


def poly_loss(
    x: Tensor,
    target: Tensor,
    eps: float = 2.0,
    weight: Optional[Tensor] = None,
    ignore_index: int = -100,
    reduction: str = "mean",
) -> Tensor:
    """Implements the Poly1 loss from `"PolyLoss: A Polynomial Expansion Perspective of Classification Loss
    Functions" <https://arxiv.org/pdf/2204.12511.pdf>`_.

    Poly1 augments cross-entropy with the leading term of its polynomial
    expansion: ``CE(x, target) + eps * (1 - P[target])``.

    Args:
        x (Tensor[N, K, ...]): unnormalized predicted scores (logits) for each of the K classes
        target (Tensor[N, ...]): ground-truth class indices in ``[0, K)`` (or equal to ``ignore_index``)
        eps (float, optional): epsilon 1 from the paper, the coefficient of the ``1 - pt`` term
        weight (Tensor[K], optional): manual rescaling of each class
        ignore_index (int, optional): specifies a target value that is ignored and does not
            contribute to the gradient. Default: -100
        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
            ``'none'``: no reduction is applied (ignored positions are set to 0);
            ``'mean'``: the output is averaged over the valid positions;
            ``'sum'``: the output is summed over the valid positions. Default: ``'mean'``

    Returns:
        Tensor: loss reduced with `reduction` method

    Raises:
        ValueError: if ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``
    """
    # Original implementation from https://github.com/frgfm/Holocron/blob/main/holocron/nn/functional.py
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(poly_loss)

    # log(P[class]) = log_softmax(score)[class]
    logpt = F.log_softmax(x, dim=1)

    flat_target = target.view(-1)
    # Mask positions equal to ignore_index (works for any value, including the
    # default -100, unlike the previous range check on [0, K) which made the
    # default a no-op).
    valid_idxs = flat_target != ignore_index
    # Clamp so that gather never sees an out-of-range index (e.g. a negative
    # ignore_index); the gathered values at ignored positions are discarded.
    safe_target = flat_target.clamp(min=0)

    # Compute pt and logpt only for target classes (the remaining have a 0 coefficient).
    # squeeze(0) rather than squeeze() so a single-element batch stays 1-d.
    logpt = logpt.transpose(1, 0).flatten(1).gather(0, safe_target.view(1, -1)).squeeze(0)

    # Poly1: CE + eps * (1 - pt)
    loss = -1 * logpt + eps * (1 - logpt.exp())

    # Per-class rescaling must scale the loss itself (the previous revision
    # multiplied `logpt` after `loss` was computed, so `weight` had no effect).
    if weight is not None:
        if weight.dtype != x.dtype or weight.device != x.device:
            weight = weight.to(dtype=x.dtype, device=x.device)
        loss = weight.gather(0, safe_target) * loss

    # Loss reduction
    if reduction == "sum":
        loss = loss[valid_idxs].sum()
    elif reduction == "mean":
        loss = loss[valid_idxs].mean()
    elif reduction == "none":
        # Zero-out ignored positions, then reshape like target
        loss = torch.where(valid_idxs, loss, torch.zeros_like(loss)).view(target.shape)
    else:
        raise ValueError(f"invalid value for arg 'reduction': {reduction}")

    return loss