
Potentially faulty usage of sigmoid argument in Masked Dice #8655

@marakrup


Hi there!

When using MaskedDiceLoss, my UNet doesn't train as well as it does with the regular DiceLoss. On closer inspection, I noticed what I believe is an inconsistency in how the `sigmoid` argument is used when initializing the loss (please correct me if I'm wrong :)).

DiceLoss, and therefore also MaskedDiceLoss, accepts the argument `sigmoid=True`, which lets the user pass in model output logits instead of model output probabilities.

With DiceLoss, the sigmoid is handled correctly. With MaskedDiceLoss, however, the sigmoid is applied only after the input and target values have been masked.

If I am reading the code correctly, MaskedDiceLoss applies the mask by setting all masked values of the model output and the target to zero, and then passes the masked output and target on to the regular DiceLoss. If `sigmoid=True` was set at initialization, the forward function of DiceLoss then applies the sigmoid to these already-masked outputs.
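A minimal plain-Python sketch of the order of operations described above may help. It does not use the library: `dice_loss` is a hypothetical, simplified stand-in for DiceLoss on flat lists (soft Dice with small smoothing terms), and `masked_dice_as_described` is a hypothetical helper that mirrors my reading of MaskedDiceLoss, masking by multiplication first and applying the sigmoid afterwards.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def dice_loss(pred, target, apply_sigmoid=False, smooth_nr=1e-5, smooth_dr=1e-5):
    # Hypothetical, simplified stand-in for DiceLoss: soft Dice over flat lists.
    if apply_sigmoid:
        pred = [sigmoid(p) for p in pred]
    intersection = sum(p * g for p, g in zip(pred, target))
    denominator = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + smooth_nr) / (denominator + smooth_dr)


def masked_dice_as_described(logits, target, mask):
    # Hypothetical helper mirroring the order described above:
    # 1) zero out masked logits and targets, 2) apply the sigmoid inside Dice.
    masked_logits = [x * m for x, m in zip(logits, mask)]
    masked_target = [g * m for g, m in zip(target, mask)]
    return dice_loss(masked_logits, masked_target, apply_sigmoid=True)


logits = [8.0, -8.0, 8.0, -8.0]  # confident predictions
target = [1.0, 0.0, 1.0, 0.0]
mask = [1.0, 1.0, 0.0, 0.0]      # only the first two elements should count

# Dice over the unmasked elements only, for comparison.
print(dice_loss(logits[:2], target[:2], apply_sigmoid=True))
# Masked logits become 0, and sigmoid(0) = 0.5 inside the loss.
print(masked_dice_as_described(logits, target, mask))
```

With these values the first print is roughly 0.0003, while the masked version reports roughly 0.33: each of the two masked elements adds 0.5 to the Dice denominator, even though the unmasked predictions are almost perfect.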

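Notably, every model output value that was set to zero during masking becomes 0.5 after the sigmoid in the DiceLoss forward function, since sigmoid(0) = 0.5. The masked region therefore enters the loss as a prediction of 0.5 against a target of 0 instead of being ignored. If I understand correctly, this is not the behavior a user would expect.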
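A short derivation shows how large the effect can be. It assumes the standard soft Dice form with smoothing terms $\epsilon_{nr}$, $\epsilon_{dr}$ and unsquared predictions, which I believe matches the default DiceLoss configuration but have not checked for every option. Let $U$ be the set of unmasked elements, $N_m$ the number of masked ones, $p_i$ the predicted probabilities and $g_i$ the targets. After masking, each masked element has $g_i = 0$ and a logit of 0, hence $p_i = \sigma(0) = \tfrac12$:

```math
\begin{aligned}
\mathcal{L}_{\text{masked}} &= 1 - \frac{2\sum_{i \in U} p_i g_i + \epsilon_{nr}}{\sum_{i \in U} p_i + \sum_{i \in U} g_i + \tfrac{1}{2} N_m + \epsilon_{dr}} \\
\mathcal{L}_{\text{masked}}\Big|_{p_i = g_i \text{ on } U,\ \epsilon = 0} &= 1 - \frac{2G}{2G + \tfrac{1}{2} N_m} = \frac{N_m}{4G + N_m}
\end{aligned}
```

The numerator is unchanged, since each masked term is $\tfrac12 \cdot 0$, but the denominator grows by $\tfrac12 N_m$. The second line assumes binary targets, perfect predictions on $U$ with $G = \sum_{i \in U} g_i$, and no smoothing: even then the loss moves towards 1 as the masked region grows. The masked values themselves carry no gradient (they were multiplied by zero), but through the denominator they rescale the gradient of the unmasked elements.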

To fix this, I'd suggest either setting the masked model outputs to -inf (instead of zero) whenever `sigmoid=True`, so that the sigmoid maps them to 0, or applying the sigmoid in the forward function of MaskedDiceLoss before masking and initializing the underlying DiceLoss with `sigmoid=False`.
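Both suggestions can be sketched in plain Python. As before, `dice_loss` is a hypothetical, simplified stand-in for DiceLoss, and the two helpers are hypothetical illustrations, not the library's code. The first applies the sigmoid before masking and then calls the Dice loss without a sigmoid; the second replaces masked logits with -inf so that the sigmoid maps them to 0.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function; sigmoid(-inf) evaluates to 0.0.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def dice_loss(pred, target, apply_sigmoid=False, smooth_nr=1e-5, smooth_dr=1e-5):
    # Hypothetical, simplified stand-in for DiceLoss: soft Dice over flat lists.
    if apply_sigmoid:
        pred = [sigmoid(p) for p in pred]
    intersection = sum(p * g for p, g in zip(pred, target))
    denominator = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + smooth_nr) / (denominator + smooth_dr)


def masked_dice_sigmoid_first(logits, target, mask):
    # Second suggestion: sigmoid first, then mask the probabilities,
    # then compute Dice without applying the sigmoid again.
    probs = [sigmoid(x) * m for x, m in zip(logits, mask)]
    masked_target = [g * m for g, m in zip(target, mask)]
    return dice_loss(probs, masked_target, apply_sigmoid=False)


def masked_dice_neg_inf(logits, target, mask):
    # First suggestion: replace masked logits with -inf so the sigmoid maps them to 0.
    # Select instead of multiplying, because -inf * 0 is NaN, not 0.
    masked_logits = [x if m > 0 else float("-inf") for x, m in zip(logits, mask)]
    masked_target = [g * m for g, m in zip(target, mask)]
    return dice_loss(masked_logits, masked_target, apply_sigmoid=True)


logits = [8.0, -8.0, 8.0, -8.0]
target = [1.0, 0.0, 1.0, 0.0]
mask = [1.0, 1.0, 0.0, 0.0]

# Dice over the unmasked elements only, as the reference value.
reference = dice_loss(logits[:2], target[:2], apply_sigmoid=True)
print(reference)
print(masked_dice_sigmoid_first(logits, target, mask))
print(masked_dice_neg_inf(logits, target, mask))
```

Both variants return the same value as the Dice computed over the unmasked elements alone. The sigmoid-first variant seems the more robust choice to me: in PyTorch the -inf variant would need `torch.where` or `Tensor.masked_fill` rather than multiplication, and it does not carry over to `softmax=True`, where a position whose channels are all -inf produces NaN. Assuming MaskedDiceLoss works as described above, a user-side workaround along the same lines would be to apply `torch.sigmoid` to the logits before calling the loss and to construct the loss with `sigmoid=False`.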

I hope my explanation is understandable; it's also possible that I'm misreading the code. Let me know. :)

Kind regards!

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
