Closed
Description
Let's check the type of the `output_transform` argument in `Metric.__init__`:
```python
from typing import Callable, Union

import torch


class Metric(...):
    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        # Fail early with a clear message instead of a confusing error at update time
        if not callable(output_transform):
            raise TypeError(
                "Argument output_transform should be callable, "
                f"got {type(output_transform)}"
            )
        self._output_transform = output_transform
        ...
```

TODO:
- Add the above check to `Metric.__init__`
- Add a test to https://github.com/pytorch/ignite/blob/master/tests/ignite/metrics/test_metric.py checking that the error is raised
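A minimal, self-contained sketch of what such a test could check. To keep it runnable without torch or ignite installed, it uses a hypothetical stand-in `Metric` that reproduces only the type check (the `device` and `skip_unrolling` arguments are omitted), plus a hypothetical concrete subclass `DummyMetric`. The subclass matters: because ignite's `Metric` declares abstract methods, instantiating it directly already raises a `TypeError` ("Can't instantiate abstract class"), so a test on `Metric` itself could pass for the wrong reason. The test is written framework-free here; in `test_metric.py` the same assertion would more naturally use `pytest.raises(TypeError, match=...)`.

```python
from abc import ABCMeta, abstractmethod
from typing import Any, Callable


class Metric(metaclass=ABCMeta):
    """Hypothetical stand-in for ignite's Metric: only the output_transform check is reproduced."""

    def __init__(self, output_transform: Callable = lambda x: x):
        if not callable(output_transform):
            raise TypeError(
                "Argument output_transform should be callable, "
                f"got {type(output_transform)}"
            )
        self._output_transform = output_transform

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def update(self, output: Any) -> None: ...

    @abstractmethod
    def compute(self) -> Any: ...


class DummyMetric(Metric):
    """Hypothetical concrete subclass so that construction actually reaches Metric.__init__."""

    def reset(self) -> None:
        pass

    def update(self, output: Any) -> None:
        pass

    def compute(self) -> Any:
        return None


def test_output_transform_must_be_callable():
    # Non-callable values of several types should all be rejected with our message.
    for bad in (None, 1, "not a function", [lambda x: x]):
        try:
            DummyMetric(output_transform=bad)
        except TypeError as e:
            assert "output_transform should be callable" in str(e)
        else:
            raise AssertionError(f"no TypeError raised for {bad!r}")

    # Callables, including the default identity, are accepted and stored.
    m = DummyMetric(output_transform=lambda out: out[0])
    assert m._output_transform((1, 2)) == 1
    assert DummyMetric()._output_transform(5) == 5
```

Matching on the message text, not just the exception type, is what distinguishes the new check from unrelated `TypeError`s raised during construction.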