🚀 Feature
The idea is to make `Metric`'s reduction/gathering ops configurable. By default, we use our own code, but the user can globally override those functions, for example if they use a custom, unsupported distributed framework or deal with asymmetry like here, etc.
EDIT:
When a metric is implemented, methods like `reset` and `update` are decorated with `reinit__is_reduced`, and `compute` is decorated with `sync_all_reduce`.
`sync_all_reduce` is implemented in `ignite/metrics/metric.py` of the ignite repository (lines 550 to 594 at commit 581f5b4):
```python
def sync_all_reduce(*attrs: Any) -> Callable:
    """Helper decorator for distributed configuration to collect instance attribute value
    across all participating processes and apply the specified reduction operation.

    See :doc:`metrics` on how to use it.

    Args:
        attrs: attribute names of decorated class

    .. versionchanged:: 0.4.5
        - Ability to handle different reduction operations (SUM, MAX, MIN, PRODUCT).
    """

    def wrapper(func: Callable) -> Callable:
        @wraps(func)
        def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable:
            if not isinstance(self, Metric):
                raise RuntimeError(
                    "Decorator sync_all_reduce should be used on ignite.metric.Metric class methods only"
                )
            ws = idist.get_world_size()
            if len(attrs) > 0 and not self._is_reduced:
                if ws > 1:
                    for attr in attrs:
                        op_kwargs = {}
                        if ":" in attr:
                            attr, op = attr.split(":")
                            valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"]
                            if op not in valid_ops:
                                raise ValueError(f"Reduction operation is not valid (expected : {valid_ops}, got: {op}")
                            op_kwargs["op"] = op
                        t = getattr(self, attr, None)
                        if t is not None:
                            t = idist.all_reduce(t, **op_kwargs)
                            self._is_reduced = True
                            setattr(self, attr, t)
                else:
                    self._is_reduced = True

            return func(self, *args, **kwargs)

        return another_wrapper

    setattr(wrapper, "_decorated", True)
    return wrapper
```
where we call `idist.all_reduce(t, **op_kwargs)`.
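Each attribute name passed to `sync_all_reduce` may carry an optional `:OP` suffix. The branch that builds `op_kwargs` can be isolated as a small helper; `parse_attr_spec` and the attribute `"_max_value"` are hypothetical names used only for illustration. Without a suffix, `op_kwargs` stays empty, so the reduction function falls back to its own default `op`; any custom replacement would therefore need to accept an optional `op` keyword.

```python
from typing import Dict, Tuple

VALID_OPS = ("MIN", "MAX", "SUM", "PRODUCT")


def parse_attr_spec(spec: str) -> Tuple[str, Dict[str, str]]:
    """Split "name" or "name:OP" into (name, kwargs for the all-reduce call).

    Mirrors the branch in sync_all_reduce: without a suffix the kwargs are empty,
    so the reduction function uses its own default op.
    """
    op_kwargs: Dict[str, str] = {}
    attr = spec
    if ":" in spec:
        attr, op = spec.split(":")
        if op not in VALID_OPS:
            raise ValueError(f"Reduction operation is not valid (expected: {list(VALID_OPS)}, got: {op})")
        op_kwargs["op"] = op
    return attr, op_kwargs


print(parse_attr_spec("_num_examples"))   # ('_num_examples', {})
print(parse_attr_spec("_max_value:MAX"))  # ('_max_value', {'op': 'MAX'})
```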
So, the issue description says:

> The idea is to make `Metric`'s reduction/gathering ops configurable. By default, we use our own code, but the user can globally override those functions.

In other words, we would like metrics to call a user-provided `all_reduce` instead of `idist.all_reduce`.
A tentative API for this feature:

```python
from typing import Union

import torch

import ignite.distributed as idist
from ignite.metrics import Accuracy
# Proposed API: these functions do not exist yet
from ignite.metrics import get_all_reduce_fn, reset_all_reduce_fn, set_all_reduce_fn


def my_all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM", **kwargs) -> Union[torch.Tensor, float]:
    # ... custom implementation, e.g. backed by another distributed framework
    ...


set_all_reduce_fn(my_all_reduce)
assert get_all_reduce_fn() is my_all_reduce

acc = Accuracy()
acc.update((y_pred, y))  # y_pred, y: model output and target for one batch
value = acc.compute()  # should call my_all_reduce

reset_all_reduce_fn()
assert get_all_reduce_fn() is idist.all_reduce
```
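One possible way to back the proposed API is a module-level registry that `sync_all_reduce` consults at call time rather than at decoration time. This is a sketch under assumptions: `default_all_reduce` is a hypothetical stand-in for `idist.all_reduce` (an identity, as in a single process) so the snippet runs without ignite or torch, and `_all_reduce_fn` and `reduce_state` are invented names.

```python
from typing import Any, Callable


def default_all_reduce(tensor: Any, op: str = "SUM", **kwargs: Any) -> Any:
    """Hypothetical stand-in for idist.all_reduce (identity, as with a single process)."""
    return tensor


_all_reduce_fn: Callable[..., Any] = default_all_reduce


def set_all_reduce_fn(fn: Callable[..., Any]) -> None:
    """Globally replace the function used by metrics to reduce their state."""
    global _all_reduce_fn
    if not callable(fn):
        raise TypeError(f"Expected a callable, got {type(fn)}")
    _all_reduce_fn = fn


def get_all_reduce_fn() -> Callable[..., Any]:
    """Return the function currently used to reduce metric state."""
    return _all_reduce_fn


def reset_all_reduce_fn() -> None:
    """Restore the default reduction function."""
    global _all_reduce_fn
    _all_reduce_fn = default_all_reduce


def reduce_state(value: Any, op: str = "SUM") -> Any:
    """What sync_all_reduce would call instead of idist.all_reduce, resolved at call time."""
    return get_all_reduce_fn()(value, op=op)


calls = []


def my_all_reduce(tensor: Any, op: str = "SUM", **kwargs: Any) -> Any:
    calls.append(op)  # record that the custom function was used
    return tensor * 2  # pretend two workers hold the same value


set_all_reduce_fn(my_all_reduce)
assert get_all_reduce_fn() is my_all_reduce
print(reduce_state(3))  # 6, via my_all_reduce

reset_all_reduce_fn()
assert get_all_reduce_fn() is default_all_reduce
print(reduce_state(3))  # 3, via the default
```

Resolving the function inside the wrapper, instead of binding it when the decorator runs, matters because metric classes are decorated at import time, before user code has a chance to call `set_all_reduce_fn`.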