Merged
Changes from 14 commits
Commits
23 commits
3484a59
Remove unnecessary code in BaseOutputHandler
sadra-barikbin Jan 22, 2022
2c8eed9
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 3, 2022
ccf2364
Add ReduceLROnPlateauScheduler
sadra-barikbin Feb 3, 2022
7f7dae6
Fix indentation issue
sadra-barikbin Feb 3, 2022
896e482
Fix another indentation issue
sadra-barikbin Feb 3, 2022
cbc8d04
Fix PEP8 related issues
sadra-barikbin Feb 3, 2022
47b0622
Fix other PEP8 related issues
sadra-barikbin Feb 3, 2022
91d058e
Fix hopefully the last PEP8 related issue
sadra-barikbin Feb 3, 2022
9fd7d61
Fix hopefully the last PEP8 related issue
sadra-barikbin Feb 3, 2022
b7dc921
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 3, 2022
e0644e3
Merge branch 'master' of https://github.com/sadra-barikbin/ignite
sadra-barikbin Feb 3, 2022
c95a2be
Remove ReduceLROnPlateau's specific params and add link to it
sadra-barikbin Feb 3, 2022
96554d0
Fix state_dict bug and add a test
sadra-barikbin Feb 5, 2022
145dabc
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 9, 2022
0aee28a
Update docs
sadra-barikbin Feb 10, 2022
307803c
Merge branch 'master' into master
vfdev-5 Feb 14, 2022
0129572
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 14, 2022
a17a5b2
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 19, 2022
b3ea962
Add doctest and fix typo
sadra-barikbin Feb 20, 2022
e2e6831
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 20, 2022
b88c9e1
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 20, 2022
8d0ae3c
Merge branch 'master' of https://github.com/sadra-barikbin/ignite
sadra-barikbin Feb 20, 2022
408b271
Merge branch 'master' into master
vfdev-5 Feb 20, 2022
1 change: 1 addition & 0 deletions ignite/contrib/handlers/__init__.py
@@ -17,6 +17,7 @@
LRScheduler,
ParamGroupScheduler,
PiecewiseLinear,
ReduceLROnPlateauScheduler,
create_lr_scheduler_with_warmup,
)
from ignite.handlers.time_profilers import BasicTimeProfiler, HandlersTimeProfiler
3 changes: 3 additions & 0 deletions ignite/contrib/handlers/param_scheduler.py
@@ -21,6 +21,7 @@
ParamGroupScheduler,
ParamScheduler,
PiecewiseLinear,
ReduceLROnPlateauScheduler,
create_lr_scheduler_with_warmup,
)

@@ -34,6 +35,7 @@
"PiecewiseLinear",
"CyclicalScheduler",
"create_lr_scheduler_with_warmup",
"ReduceLROnPlateauScheduler",
]

ConcatScheduler = ConcatScheduler
@@ -45,3 +47,4 @@
PiecewiseLinear = PiecewiseLinear
CyclicalScheduler = CyclicalScheduler
create_lr_scheduler_with_warmup = create_lr_scheduler_with_warmup
ReduceLROnPlateauScheduler = ReduceLROnPlateauScheduler
4 changes: 3 additions & 1 deletion ignite/handlers/__init__.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional

from ignite.engine import Engine
from ignite.engine.events import Events
@@ -16,6 +16,7 @@
ParamGroupScheduler,
ParamScheduler,
PiecewiseLinear,
ReduceLROnPlateauScheduler,
create_lr_scheduler_with_warmup,
)
from ignite.handlers.state_param_scheduler import (
@@ -62,6 +63,7 @@
"ExpStateScheduler",
"StepStateScheduler",
"MultiStepStateScheduler",
"ReduceLROnPlateauScheduler",
]


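With the new entry in `__all__`, the scheduler is importable directly from `ignite.handlers`. A minimal construction sketch (the dummy tensor and optimizer below are illustrative, not part of the PR):

import torch

from ignite.handlers import ReduceLROnPlateauScheduler

# A throwaway single-parameter optimizer, just enough to build the scheduler.
tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)

# Extra keyword arguments are forwarded to the wrapped ReduceLROnPlateau.
scheduler = ReduceLROnPlateauScheduler(
    optimizer, metric_name="accuracy", mode="max", factor=0.5, patience=2
)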
95 changes: 94 additions & 1 deletion ignite/handlers/param_scheduler.py
@@ -9,7 +9,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast

import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
from torch.optim.optimizer import Optimizer

from ignite.engine import Engine
@@ -1393,6 +1393,99 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar
return values


class ReduceLROnPlateauScheduler(ParamScheduler):
"""Reduce LR when a metric stops improving.
Wrapper of `torch.optim.lr_scheduler.ReduceLROnPlateau
<https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html>`_.

Args:
optimizer: Wrapped optimizer.
metric_name: metric whose improvement is monitored.
Must be attached to the same engine.
trainer: Trainer engine to log LR history in its
`state.param_history`. Used only if `save_history`
is true. Default: None.
save_history: Whether to save history or not. If true,
history will be logged in `trainer`'s `state.param_history`.
Default: False.
param_group_index: index of `optimizer`'s parameter group
to use. Default: None, meaning all of `optimizer`'s parameter groups are used.
**scheduler_kwargs: Keyword arguments to be passed to the wrapped
`ReduceLROnPlateau`.

Examples:

.. code-block:: python

# Metric 'metric-name' should surpass its best value by
# more than 1 unit after at most 2 epochs, otherwise LR
# would get multiplied by 0.5.

scheduler = ReduceLROnPlateauScheduler(
optimizer,
metric_name="metric-name", mode="max",
factor=0.5, patience=1, threshold_mode='abs',
threshold=1, trainer=trainer
)

evaluator.add_event_handler(Events.COMPLETED, scheduler)

"""

def __init__(
self,
optimizer: Optimizer,
metric_name: str,
trainer: Optional[Engine] = None,
save_history: bool = False,
param_group_index: Optional[int] = None,
**scheduler_kwargs: Any,
):
super(ReduceLROnPlateauScheduler, self).__init__(
optimizer, "lr", save_history=save_history, param_group_index=param_group_index
)
self.metric_name = metric_name
self.trainer = trainer
self.optimizer = optimizer

if "min_lr" in scheduler_kwargs and param_group_index is not None:
min_lr = scheduler_kwargs["min_lr"]
if not isinstance(min_lr, float):
raise TypeError(f"When param_group_index is given, min_lr should be a float, but given {type(min_lr)}")
_min_lr = min_lr
min_lr = [0] * len(optimizer.param_groups)
min_lr[param_group_index] = _min_lr
else:
min_lr = 0
_scheduler_kwargs = scheduler_kwargs.copy()
_scheduler_kwargs["min_lr"] = min_lr

self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
self.scheduler._reduce_lr = self._reduce_lr # type: ignore[attr-defined]

self._state_attrs += ["metric_name", "scheduler"]

def __call__(self, engine: Engine, name: Optional[str] = None) -> None: # type: ignore[override]
if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
raise ValueError(
"Argument engine should have in its 'state', attribute 'metrics' "
f"which itself has the metric {self.metric_name}."
)
self.scheduler.step(engine.state.metrics[self.metric_name])
super().__call__(self.trainer, name)

def get_param(self) -> Union[float, List[float]]:
lrs = [pg["lr"] for pg in self.optimizer_param_groups]
return lrs[0] if len(lrs) == 1 else lrs

def _reduce_lr(self, epoch: int) -> None:
for i, param_group in enumerate(self.optimizer_param_groups):
old_lr = float(param_group["lr"])
new_lr = max(old_lr * self.scheduler.factor, self.scheduler.min_lrs[i]) # type: ignore[attr-defined]
if old_lr - new_lr > self.scheduler.eps: # type: ignore[attr-defined]
param_group["lr"] = new_lr


def _get_fake_optimizer(
optimizer_cls: Optional[Union[Type[Optimizer], Type[torch.optim.SGD]]] = None, **kwargs: Any
) -> Union[Optimizer, torch.optim.SGD]:
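Putting the pieces above together, a hedged usage sketch of the wiring described in the docstring and exercised by the tests below: a trainer runs an evaluator each epoch, the scheduler is attached to the evaluator's COMPLETED event, and only parameter group 0 is scheduled. The step functions and the fake metric handler are placeholders standing in for real training and validation logic:

import itertools

import torch

from ignite.engine import Engine, Events
from ignite.handlers import ReduceLROnPlateauScheduler

# Two parameter groups; param_group_index=0 means only group 0 gets its LR reduced.
w0 = torch.zeros([1], requires_grad=True)
w1 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([{"params": [w0]}, {"params": [w1]}], lr=1.0)

trainer = Engine(lambda engine, batch: None)    # placeholder training step
evaluator = Engine(lambda engine, batch: None)  # placeholder validation step

scheduler = ReduceLROnPlateauScheduler(
    optimizer,
    metric_name="acc",       # read from evaluator.state.metrics when the handler fires
    mode="max",
    factor=0.5,
    patience=1,
    min_lr=1e-7,             # scalar min_lr is broadcast so only group 0 is clamped
    param_group_index=0,
    save_history=True,
    trainer=trainer,         # LR history ends up in trainer.state.param_history["lr"]
)

fake_scores = itertools.cycle([0.3, 0.4, 0.5])  # stand-in for a real attached metric

@evaluator.on(Events.COMPLETED)
def write_metric():
    evaluator.state.metrics["acc"] = next(fake_scores)

# Registered after write_metric, so the metric is in place when the scheduler fires.
evaluator.add_event_handler(Events.COMPLETED, scheduler)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation():
    evaluator.run([0] * 8)

trainer.run([0] * 8, max_epochs=5)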
71 changes: 71 additions & 0 deletions tests/ignite/handlers/test_param_scheduler.py
@@ -14,6 +14,7 @@
ParamGroupScheduler,
ParamScheduler,
PiecewiseLinear,
ReduceLROnPlateauScheduler,
create_lr_scheduler_with_warmup,
)
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer
@@ -1302,3 +1303,73 @@ def save_lr(engine):
assert lrs == list(
map(pytest.approx, [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])
)


def test_reduce_lr_on_plateau_scheduler():
tensor1 = torch.zeros([1], requires_grad=True)
tensor2 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([{"params": [tensor1]}, {"params": [tensor2]}], lr=1)

data = [0] * 8
max_epochs = 10

trainer = Engine(lambda engine, batch: None)

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate():
evaluator.run(data)

scheduler = ReduceLROnPlateauScheduler(
optimizer,
metric_name="acc",
mode="max",
factor=0.5,
patience=1,
threshold_mode="abs",
threshold=1.99,
min_lr=1e-7,
save_history=True,
trainer=trainer,
param_group_index=0,
)
evaluator = Engine(lambda engine, batch: None)
evaluator.state.metrics = {"acc": 0.0}
generate_acc = iter([3, 7, 7, 9, 10, 11, 8, 8, 4, 7])

@evaluator.on(Events.COMPLETED)
def set_acc():
evaluator.state.metrics["acc"] = next(generate_acc)

evaluator.add_event_handler(Events.COMPLETED, scheduler)

trainer.run(data, max_epochs=max_epochs)

lrs = [param[0] for param in trainer.state.param_history["lr"]]
assert lrs == list(
map(
pytest.approx,
[1, 1, 1, 1, 1, 1, 1, 0.5, 0.5, 0.25],
)
)
assert optimizer.param_groups[1]["lr"] == 1


def test_reduce_lr_on_plateau_scheduler_asserts():
tensor1 = torch.zeros([1], requires_grad=True)
tensor2 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([{"params": [tensor1]}, {"params": [tensor2]}], lr=1)

with pytest.raises(TypeError, match=r"When param_group_index is given, min_lr should be a float, but given"):
ReduceLROnPlateauScheduler(
optimizer,
metric_name="acc",
min_lr=[1e-7, 1e-8],
param_group_index=0,
)

with pytest.raises(
ValueError, match=r"Argument engine should have in its 'state', attribute 'metrics' which itself has the metric"
):
scheduler = ReduceLROnPlateauScheduler(optimizer, metric_name="acc")
evaluator = Engine(lambda engine, batch: None)
scheduler(evaluator)
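Because `metric_name` and the wrapped `scheduler` are added to `_state_attrs`, the handler's state should survive a checkpoint round trip through `ParamScheduler`'s usual `state_dict`/`load_state_dict` machinery, which is what the "Fix state_dict bug and add a test" commit targets. A minimal sketch of that round trip (assuming that machinery; the loss values are illustrative):

import torch

from ignite.handlers import ReduceLROnPlateauScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)

scheduler = ReduceLROnPlateauScheduler(
    optimizer, metric_name="loss", mode="min", factor=0.5, patience=2
)

# Drive the wrapped ReduceLROnPlateau directly so its bookkeeping (and the LR) changes.
for loss in [1.0, 1.0, 1.0, 1.0]:
    scheduler.scheduler.step(loss)

state = scheduler.state_dict()  # includes metric_name and the wrapped scheduler's state

# A freshly built handler restored from that state should pick up the same bookkeeping.
restored = ReduceLROnPlateauScheduler(
    optimizer, metric_name="loss", mode="min", factor=0.5, patience=2
)
restored.load_state_dict(state)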