
Add output_transform type check in Metric class #3352

@vfdev-5

Description

Let's check the type of the output_transform input argument in Metric.__init__ and raise a TypeError when it is not callable:

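```python
from typing import Callable, Union

import torch


class Metric(...):  # base classes elided

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        # Proposed: reject non-callable values at construction time with a
        # clear message, instead of failing later when the transform is called.
        if not callable(output_transform):
            raise TypeError(
                "Argument output_transform should be callable, "
                f"got {type(output_transform)}"
            )
        self._output_transform = output_transform
        ...  # rest of __init__ unchanged
```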
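A minimal, torch-free sketch of how the proposed check behaves. MetricSketch and check_rejects are hypothetical names used only for illustration; MetricSketch reduces the Metric constructor above to the single type check, so a real test in ignite would instead instantiate a concrete Metric subclass.

```python
from typing import Any, Callable


class MetricSketch:
    """Hypothetical stand-in for Metric, reduced to the proposed type check."""

    def __init__(self, output_transform: Callable[[Any], Any] = lambda x: x) -> None:
        # Same check as proposed for Metric.__init__.
        if not callable(output_transform):
            raise TypeError(
                "Argument output_transform should be callable, "
                f"got {type(output_transform)}"
            )
        self._output_transform = output_transform


def check_rejects(bad_value: Any) -> str:
    """Return the TypeError message raised for a non-callable output_transform."""
    try:
        MetricSketch(output_transform=bad_value)
    except TypeError as exc:
        return str(exc)
    raise AssertionError(f"{bad_value!r} was accepted")


# Any callable passes: lambdas, functions, functools.partial, objects with __call__.
MetricSketch(output_transform=lambda out: out[0])

# A non-callable value fails immediately, at construction time.
print(check_rejects(("y_pred", "y")))
# prints: Argument output_transform should be callable, got <class 'tuple'>
```

Because callable() only checks that the object can be called, the check does not validate the transform's signature or return value; it only catches values that could never be applied to the engine output.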

TODO:
