
Added focal_loss_with_probs and full activation selection in Dice+FL. #8493


Open · wants to merge 3 commits into base: dev
42 changes: 28 additions & 14 deletions monai/losses/dice.py
@@ -823,8 +823,10 @@ def __init__(
self,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
sigmoid_dice: bool = False,
softmax_dice: bool = False,
sigmoid_focal: bool = True,
softmax_focal: bool = False,
other_act: Callable | None = None,
squared_pred: bool = False,
jaccard: bool = False,
@@ -843,10 +845,10 @@ def __init__(
include_background: if False channel index 0 (background category) is excluded from the calculation.
to_onehot_y: whether to convert the ``target`` into the one-hot format,
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
don't need to specify activation function for `FocalLoss`.
softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
don't need to specify activation function for `FocalLoss`.
sigmoid_dice: if True, apply a sigmoid function to the prediction for the `DiceLoss`.
softmax_dice: if True, apply a softmax function to the prediction for the `DiceLoss`.
sigmoid_focal: if True, apply a sigmoid function to the prediction for `FocalLoss`.
softmax_focal: if True, apply a softmax function to the prediction for `FocalLoss`.
other_act: callable function to execute other activation layers, Defaults to ``None``.
for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
squared_pred: use squared versions of targets and predictions in the denominator or not.
@@ -878,8 +880,8 @@ def __init__(
self.dice = DiceLoss(
include_background=include_background,
to_onehot_y=False,
sigmoid=sigmoid,
softmax=softmax,
sigmoid=sigmoid_dice,
softmax=softmax_dice,
other_act=other_act,
squared_pred=squared_pred,
jaccard=jaccard,
@@ -896,6 +898,8 @@ def __init__(
weight=weight,
alpha=alpha,
reduction=reduction,
use_sigmoid=sigmoid_focal,
use_softmax=softmax_focal,
)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
@@ -953,8 +957,14 @@ class GeneralizedDiceFocalLoss(_Loss):
Defaults to True.
to_onehot_y: whether to convert the ``target`` into the one-hot format,
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
sigmoid_dice (bool, optional): if True, apply a sigmoid function to the prediction for `GeneralizedDiceLoss`.
Defaults to False.
softmax_dice (bool, optional): if True, apply a softmax function to the prediction for `GeneralizedDiceLoss`.
Defaults to False.
sigmoid_focal (bool, optional): if True, apply a sigmoid function to the prediction for `FocalLoss`.
Defaults to True.
softmax_focal (bool, optional): if True, apply a softmax function to the prediction for `FocalLoss`.
Defaults to False.
other_act (Optional[Callable], optional): callable function to execute other activation layers,
Defaults to ``None``. for example: `other_act = torch.tanh`.
only used by the `GeneralizedDiceLoss`, not for the `FocalLoss`.
@@ -987,8 +997,10 @@ def __init__(
self,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
sigmoid_dice: bool = False,
softmax_dice: bool = False,
sigmoid_focal: bool = True,
softmax_focal: bool = False,
other_act: Callable | None = None,
w_type: Weight | str = Weight.SQUARE,
reduction: LossReduction | str = LossReduction.MEAN,
@@ -1004,8 +1016,8 @@ def __init__(
self.generalized_dice = GeneralizedDiceLoss(
include_background=include_background,
to_onehot_y=to_onehot_y,
sigmoid=sigmoid,
softmax=softmax,
sigmoid=sigmoid_dice,
softmax=softmax_dice,
other_act=other_act,
w_type=w_type,
reduction=reduction,
@@ -1019,6 +1031,8 @@ def __init__(
gamma=gamma,
weight=weight,
reduction=reduction,
use_sigmoid=sigmoid_focal,
use_softmax=softmax_focal,
)
if lambda_gdl < 0.0:
raise ValueError("lambda_gdl should be no less than 0.0.")
35 changes: 33 additions & 2 deletions monai/losses/focal_loss.py
@@ -74,6 +74,7 @@ def __init__(
weight: Sequence[float] | float | int | torch.Tensor | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
use_sigmoid: bool = True,
) -> None:
"""
Args:
@@ -96,7 +97,9 @@ def __init__(
- ``"sum"``: the output will be summed.

use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.
If True, softmax is used. Defaults to False.
use_sigmoid: whether to use sigmoid to transform the original logits into probabilities.
If True (and ``use_softmax`` is False), sigmoid is used. If both flags are False, the input
is assumed to already contain probabilities. Defaults to True.

Example:
>>> import torch
@@ -113,6 +116,7 @@ def __init__(
self.alpha = alpha
self.weight = weight
self.use_softmax = use_softmax
self.use_sigmoid = use_sigmoid
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
@@ -161,8 +165,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
self.alpha = None
warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
else:
elif self.use_sigmoid:
loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
else:
loss = focal_loss_with_probs(input, target, self.gamma, self.alpha)

num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
@@ -253,3 +259,28 @@ def sigmoid_focal_loss(
loss = alpha_factor * loss

return loss


def focal_loss_with_probs(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

where p = x (the input is expected to already be a probability), pt = p if label is 1 or 1 - p if label is 0
"""
# Compute pt (probability of true class)
pt = torch.where(target == 1, input, 1 - input)

# Compute focal loss components
log_pt = torch.log(torch.clamp(pt, min=1e-8)) # Avoid log(0)
focal_factor = (1 - pt).pow(gamma) # (1 - pt)**gamma

loss: torch.Tensor = -focal_factor * log_pt

if alpha is not None:
# alpha if t==1; (1-alpha) if t==0
alpha_factor = torch.where(target == 1, alpha, 1 - alpha)
loss = alpha_factor * loss

return loss
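
A minimal sketch of the new probability path, checked by hand against the formula above (the ``use_sigmoid``/``use_softmax`` flags are the ones added on this branch; the tensors are illustrative):

```python
import torch

from monai.losses import FocalLoss

probs = torch.tensor([[[[0.9, 0.2], [0.6, 0.8]]]])   # (B=1, C=1, H=2, W=2), already probabilities
target = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])  # binary ground truth, same shape

# with both activations disabled, forward() falls through to focal_loss_with_probs
loss_fn = FocalLoss(use_softmax=False, use_sigmoid=False, gamma=2.0, reduction="none")
loss = loss_fn(probs, target)

# manual check: FL(pt) = -(1 - pt)**gamma * log(pt), with pt = p for label 1 and 1 - p for label 0
pt = torch.where(target == 1, probs, 1 - probs)
expected = -(1 - pt) ** 2 * torch.log(pt.clamp(min=1e-8))
assert torch.allclose(loss, expected)
```
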
6 changes: 2 additions & 4 deletions tests/losses/test_dice_focal_loss.py
@@ -53,14 +53,12 @@ def test_result_no_onehot_no_bg(self, size, onehot):
for lambda_focal in [0.5, 1.0, 1.5]:
common_params = {
"include_background": False,
"softmax": True,
"to_onehot_y": onehot,
"reduction": reduction,
"weight": weight,
}
dice_focal = DiceFocalLoss(lambda_focal=lambda_focal, **common_params)
dice = DiceLoss(**common_params)
common_params.pop("softmax", None)
dice_focal = DiceFocalLoss(lambda_focal=lambda_focal, softmax_dice=True, **common_params)
dice = DiceLoss(softmax=True, **common_params)
focal = FocalLoss(**common_params)
result = dice_focal(pred, label)
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
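
The identity this test exercises, restated as a standalone sketch (flag names assume this branch; shapes and the lambda value are arbitrary):

```python
import torch

from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss

pred = torch.randn(2, 3, 16, 16)
label = torch.randint(0, 3, (2, 1, 16, 16))

lambda_focal = 1.5
combined = DiceFocalLoss(to_onehot_y=True, softmax_dice=True, lambda_focal=lambda_focal)
dice = DiceLoss(to_onehot_y=True, softmax=True)
focal = FocalLoss(to_onehot_y=True)  # sigmoid path, matching the sigmoid_focal=True default

assert torch.allclose(
    combined(pred, label), dice(pred, label) + lambda_focal * focal(pred, label)
)
```
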
63 changes: 41 additions & 22 deletions tests/losses/test_focal_loss.py
@@ -12,6 +12,7 @@
from __future__ import annotations

import unittest
from itertools import product

import numpy as np
import torch
@@ -205,59 +206,71 @@ def test_consistency_with_cross_entropy_classification_01(self):
self.assertNotAlmostEqual(max_error, 0.0, places=3)

def test_bin_seg_2d(self):
for use_softmax in [True, False]:
for use_softmax, use_sigmoid in product([True, False], repeat=2):
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
if not use_sigmoid and not use_softmax:
# The predictions here are probabilities, not logits.
pred_very_good = F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
else:
pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0

# initialize the mean dice loss
loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

# focal loss for pred_very_good should be close to 0
target = target.unsqueeze(1) # shape (1, 1, H, W)
focal_loss_good = float(loss(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

# with alpha
loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
focal_loss_good = float(loss(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_empty_class_2d(self):
for use_softmax in [True, False]:
for use_softmax, use_sigmoid in product([True, False], repeat=2):
num_classes = 2
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
if not use_sigmoid and not use_softmax:
# The predictions here are probabilities, not logits.
pred_very_good = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
else:
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0

# initialize the mean dice loss
loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

# focal loss for pred_very_good should be close to 0
target = target.unsqueeze(1) # shape (1, 1, H, W)
focal_loss_good = float(loss(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

# with alpha
loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
focal_loss_good = float(loss(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_multi_class_seg_2d(self):
for use_softmax in [True, False]:
for use_softmax, use_sigmoid in product([True, False], repeat=2):
num_classes = 6 # labels 0 to 5
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
if not use_sigmoid and not use_softmax:
# The predictions here are probabilities, not logits.
pred_very_good = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
else:
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
# initialize the mean dice loss
loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

# focal loss for pred_very_good should be close to 0
target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2) # test one hot
@@ -270,15 +283,15 @@ def test_multi_class_seg_2d(self):
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

# with alpha
loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
focal_loss_good = float(loss(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_bin_seg_3d(self):
for use_softmax in [True, False]:
for use_softmax, use_sigmoid in product([True, False], repeat=2):
num_classes = 2 # labels 0, 1
# define 3d examples
target = torch.tensor(
@@ -294,11 +307,17 @@ def test_bin_seg_3d(self):
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W, D)
target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3) # test one hot
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
if not use_sigmoid and not use_softmax:
# The predictions here are probabilities, not logits.
pred_very_good = target_one_hot.clone().float()
else:
pred_very_good = (
1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
)

# initialize the mean dice loss
loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax, use_sigmoid=use_sigmoid)

# focal loss for pred_very_good should be close to 0
target = target.unsqueeze(1) # shape (1, 1, H, W)
@@ -309,10 +328,10 @@ def test_bin_seg_3d(self):
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

# with alpha
loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
focal_loss_good = float(loss(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

@@ -369,8 +388,8 @@ def test_warnings(self):
loss(chn_input, chn_target)

def test_script(self):
for use_softmax in [True, False]:
loss = FocalLoss(use_softmax=use_softmax)
for use_softmax, use_sigmoid in product([True, False], repeat=2):
loss = FocalLoss(use_softmax=use_softmax, use_sigmoid=use_sigmoid)
test_input = torch.ones(2, 2, 8, 8)
test_script_save(loss, test_input, test_input)
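
The tests above sweep all four ``(use_softmax, use_sigmoid)`` combinations. Per the updated ``forward()`` in focal_loss.py, softmax takes precedence, sigmoid is the fallback, and the plain-probability path is reached only when both are False; a tiny sketch of that dispatch order:

```python
from itertools import product

# mirrors the branch order in FocalLoss.forward on this branch
for use_softmax, use_sigmoid in product([True, False], repeat=2):
    if use_softmax:
        path = "softmax_focal_loss"
    elif use_sigmoid:
        path = "sigmoid_focal_loss"
    else:
        path = "focal_loss_with_probs"
    print(f"use_softmax={use_softmax}, use_sigmoid={use_sigmoid} -> {path}")
```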
