diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index ed88100edd..26ff25ca7f 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -823,8 +823,10 @@ def __init__(
         self,
         include_background: bool = True,
         to_onehot_y: bool = False,
-        sigmoid: bool = False,
-        softmax: bool = False,
+        sigmoid_dice: bool = False,
+        softmax_dice: bool = False,
+        sigmoid_focal: bool = True,
+        softmax_focal: bool = False,
         other_act: Callable | None = None,
         squared_pred: bool = False,
         jaccard: bool = False,
@@ -843,10 +845,10 @@ def __init__(
             include_background: if False channel index 0 (background category) is excluded from the calculation.
             to_onehot_y: whether to convert the ``target`` into the one-hot format, using the number of classes inferred
                 from `input` (``input.shape[1]``). Defaults to False.
-            sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `FocalLoss`.
-            softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `FocalLoss`.
+            sigmoid_dice: if True, apply a sigmoid function to the prediction for the `DiceLoss`.
+            softmax_dice: if True, apply a softmax function to the prediction for the `DiceLoss`.
+            sigmoid_focal: if True, apply a sigmoid function to the prediction for `FocalLoss`.
+            softmax_focal: if True, apply a softmax function to the prediction for `FocalLoss`.
             other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
                 `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
@@ -878,8 +880,8 @@ def __init__(
         self.dice = DiceLoss(
             include_background=include_background,
             to_onehot_y=False,
-            sigmoid=sigmoid,
-            softmax=softmax,
+            sigmoid=sigmoid_dice,
+            softmax=softmax_dice,
             other_act=other_act,
             squared_pred=squared_pred,
             jaccard=jaccard,
@@ -896,6 +898,8 @@ def __init__(
             weight=weight,
             alpha=alpha,
             reduction=reduction,
+            use_sigmoid=sigmoid_focal,
+            use_softmax=softmax_focal,
         )
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
@@ -953,8 +957,14 @@ class GeneralizedDiceFocalLoss(_Loss):
             Defaults to True.
         to_onehot_y: whether to convert the ``target`` into the one-hot format, using the number of classes inferred
             from `input` (``input.shape[1]``). Defaults to False.
-        sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
-        softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
+        sigmoid_dice (bool, optional): if True, apply a sigmoid function to the prediction for `GeneralizedDiceLoss`.
+            Defaults to False.
+        softmax_dice (bool, optional): if True, apply a softmax function to the prediction for `GeneralizedDiceLoss`.
+            Defaults to False.
+        sigmoid_focal (bool, optional): if True, apply a sigmoid function to the prediction for `FocalLoss`.
+            Defaults to True.
+        softmax_focal (bool, optional): if True, apply a softmax function to the prediction for `FocalLoss`.
+            Defaults to False.
         other_act (Optional[Callable], optional): callable function to execute other activation layers,
             Defaults to ``None``. for example: `other_act = torch.tanh`. only used by the `GeneralizedDiceLoss`,
             not for the `FocalLoss`.
@@ -987,8 +997,10 @@ def __init__(
         self,
         include_background: bool = True,
         to_onehot_y: bool = False,
-        sigmoid: bool = False,
-        softmax: bool = False,
+        sigmoid_dice: bool = False,
+        softmax_dice: bool = False,
+        sigmoid_focal: bool = True,
+        softmax_focal: bool = False,
         other_act: Callable | None = None,
         w_type: Weight | str = Weight.SQUARE,
         reduction: LossReduction | str = LossReduction.MEAN,
@@ -1004,8 +1016,8 @@ def __init__(
         self.generalized_dice = GeneralizedDiceLoss(
             include_background=include_background,
             to_onehot_y=to_onehot_y,
-            sigmoid=sigmoid,
-            softmax=softmax,
+            sigmoid=sigmoid_dice,
+            softmax=softmax_dice,
             other_act=other_act,
             w_type=w_type,
             reduction=reduction,
@@ -1019,6 +1031,8 @@ def __init__(
             gamma=gamma,
             weight=weight,
             reduction=reduction,
+            use_sigmoid=sigmoid_focal,
+            use_softmax=softmax_focal,
         )
         if lambda_gdl < 0.0:
             raise ValueError("lambda_gdl should be no less than 0.0.")
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 28d1c0cdc9..2cd35c4b60 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -74,6 +74,7 @@ def __init__(
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
+        use_sigmoid: bool = True,
     ) -> None:
         """
         Args:
@@ -96,7 +97,9 @@ def __init__(
                 - ``"sum"``: the output will be summed.
 
             use_softmax: whether to use softmax to transform the original logits into probabilities.
-                If True, softmax is used. If False, sigmoid is used. Defaults to False.
+                If True, softmax is used. Defaults to False.
+            use_sigmoid: whether to use sigmoid to transform the original logits into probabilities.
+                If True, sigmoid is used. Defaults to True.
 
         Example:
             >>> import torch
@@ -113,6 +116,7 @@ def __init__(
         self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
+        self.use_sigmoid = use_sigmoid
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
@@ -161,8 +165,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
                 self.alpha = None
                 warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
             loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
-        else:
+        elif self.use_sigmoid:
             loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
+        else:
+            loss = focal_loss_with_probs(input, target, self.gamma, self.alpha)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -253,3 +259,28 @@ def sigmoid_focal_loss(
         loss = alpha_factor * loss
 
     return loss
+
+
+def focal_loss_with_probs(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
+) -> torch.Tensor:
+    """
+    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
+
+    where p = x, pt = p if label is 1 or 1 - p if label is 0
+    """
+    # Compute pt (probability of true class)
+    pt = torch.where(target == 1, input, 1 - input)
+
+    # Compute focal loss components
+    log_pt = torch.log(torch.clamp(pt, min=1e-8))  # Avoid log(0)
+    focal_factor = (1 - pt).pow(gamma)  # (1 - pt)**gamma
+
+    loss: torch.Tensor = -focal_factor * log_pt
+
+    if alpha is not None:
+        # alpha if t==1; (1-alpha) if t==0
+        alpha_factor = torch.where(target == 1, alpha, 1 - alpha)
+        loss = alpha_factor * loss
+
+    return loss
diff --git a/tests/losses/test_dice_focal_loss.py b/tests/losses/test_dice_focal_loss.py
index 98ea475ded..4d27f34a7f 100644
--- a/tests/losses/test_dice_focal_loss.py
+++ b/tests/losses/test_dice_focal_loss.py
@@ -53,14 +53,12 @@ def test_result_no_onehot_no_bg(self, size, onehot):
                 for lambda_focal in [0.5, 1.0, 1.5]:
                     common_params = {
                         "include_background": False,
-                        "softmax": True,
                         "to_onehot_y": onehot,
                         "reduction": reduction,
                         "weight": weight,
                     }
-                    dice_focal = DiceFocalLoss(lambda_focal=lambda_focal, **common_params)
-                    dice = DiceLoss(**common_params)
-                    common_params.pop("softmax", None)
+                    dice_focal = DiceFocalLoss(lambda_focal=lambda_focal, softmax_dice=True, **common_params)
+                    dice = DiceLoss(softmax=True, **common_params)
                     focal = FocalLoss(**common_params)
                     result = dice_focal(pred, label)
                     expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
diff --git a/tests/losses/test_focal_loss.py b/tests/losses/test_focal_loss.py
index e7f447d90e..9fb9299d02 100644
--- a/tests/losses/test_focal_loss.py
+++ b/tests/losses/test_focal_loss.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 import unittest
+from itertools import product
 
 import numpy as np
 import torch
@@ -205,15 +206,19 @@ def test_consistency_with_cross_entropy_classification_01(self):
         self.assertNotAlmostEqual(max_error, 0.0, places=3)
 
     def test_bin_seg_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
+            if not use_sigmoid and not use_softmax:
+                # The predictions here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0
 
             # initialize the mean dice loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
 
             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W)
@@ -221,21 +226,25 @@ def test_bin_seg_2d(self):
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
     def test_empty_class_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 2
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The predictions here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
 
             # initialize the mean dice loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
 
             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W)
@@ -243,21 +252,25 @@ def test_empty_class_2d(self):
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
     def test_multi_class_seg_2d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 6  # labels 0 to 5
             # define 2d examples
             target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W)
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The predictions here are probabilities, not logits.
+                pred_very_good = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
+            else:
+                pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
             # initialize the mean dice loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
-            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
 
             # focal loss for pred_very_good should be close to 0
             target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)  # test one hot
@@ -270,15 +283,15 @@ def test_multi_class_seg_2d(self):
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
     def test_bin_seg_3d(self):
-        for use_softmax in [True, False]:
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
             num_classes = 2  # labels 0, 1
             # define 3d examples
             target = torch.tensor(
@@ -294,11 +307,17 @@ def test_bin_seg_3d(self):
             # add another dimension corresponding to the batch (batch size = 1 here)
             target = target.unsqueeze(0)  # shape (1, H, W, D)
             target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3)  # test one hot
-            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
+            if not use_sigmoid and not use_softmax:
+                # The predictions here are probabilities, not logits.
+                pred_very_good = target_one_hot.clone().float()
+            else:
+                pred_very_good = (
+                    1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0
+                )
 
             # initialize the mean dice loss
-            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)
-            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
+            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
 
             # focal loss for pred_very_good should be close to 0
             target = target.unsqueeze(1)  # shape (1, 1, H, W)
@@ -309,10 +328,10 @@ def test_bin_seg_3d(self):
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
             # with alpha
-            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)
+            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss(pred_very_good, target).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
-            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)
+            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax, use_sigmoid=use_sigmoid)
             focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
             self.assertAlmostEqual(focal_loss_good, 0.0, places=3)
 
@@ -369,8 +388,8 @@ def test_warnings(self):
             loss(chn_input, chn_target)
 
     def test_script(self):
-        for use_softmax in [True, False]:
-            loss = FocalLoss(use_softmax=use_softmax)
+        for use_softmax, use_sigmoid in product([True, False], repeat=2):
+            loss = FocalLoss(use_softmax=use_softmax, use_sigmoid=use_sigmoid)
            test_input = torch.ones(2, 2, 8, 8)
            test_script_save(loss, test_input, test_input)
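A minimal usage sketch of the new options, assuming the patch above is applied. The tensor shapes, values, and variable names (pred_logits, pred_probs, focal_on_probs, dice_focal) are illustrative placeholders, not part of the patch or the MONAI test suite:

    import torch
    from monai.losses import DiceFocalLoss, FocalLoss

    pred_logits = torch.randn(2, 3, 32, 32)         # (B, C, H, W) raw network outputs
    pred_probs = torch.softmax(pred_logits, dim=1)  # already-normalized probabilities
    target = torch.randint(0, 3, (2, 1, 32, 32))    # class-index labels

    # With both activations disabled, FocalLoss treats the input as probabilities
    # and dispatches to the new focal_loss_with_probs helper added in this patch.
    focal_on_probs = FocalLoss(to_onehot_y=True, use_softmax=False, use_sigmoid=False)
    print(focal_on_probs(pred_probs, target))

    # DiceFocalLoss can now select the activation per term: softmax for the Dice
    # component, sigmoid for the Focal component (sigmoid_focal defaults to True).
    dice_focal = DiceFocalLoss(to_onehot_y=True, softmax_dice=True, sigmoid_focal=True)
    print(dice_focal(pred_logits, target))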