
Conversation

@AlexanderChaptykov (Contributor) commented May 8, 2023

Fixes #2910

Description:

Fixes a bug where create_lr_scheduler_with_warmup did not work correctly with CosineAnnealingWarmRestarts (#2910).
Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

Plotting the learning rates, comparing the wrapped scheduler against the standalone CosineAnnealingWarmRestarts:

import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers import create_lr_scheduler_with_warmup


def plot(warmup_end_value):
    lr = 0.2
    warm_steps = 5
    steps = 10
    warm_start = 0.023

    def get_optim():
        t1 = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([t1], lr=lr)

    def get_cos_shed():
        return CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=3, verbose=False)

    # Record the standalone cosine schedule for reference
    optimizer = get_optim()
    scheduler = get_cos_shed()
    cosine_lrs = []
    for i in range(steps):
        cosine_lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    # Wrap the same cosine scheduler with a linear warmup phase
    optimizer = get_optim()
    scheduler = create_lr_scheduler_with_warmup(
        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
    )

    warm_lrs = []
    # When warmup_end_value is None, it defaults to the optimizer's initial lr,
    # so the last warmup step coincides with the first cosine step
    real_warm_steps = warm_steps if warmup_end_value is not None else (warm_steps - 1)
    for epoch in range(real_warm_steps + steps):
        scheduler(None)
        warm_lrs.append(optimizer.param_groups[0]["lr"])

    if warmup_end_value is not None:
        plt.title("warmup_end_value != lr")
        plt.scatter(range(len(warm_lrs[:real_warm_steps])), warm_lrs[:real_warm_steps])
        plt.scatter(range(warm_steps, len(warm_lrs[real_warm_steps:]) + warm_steps), warm_lrs[real_warm_steps:])
        plt.show()
    else:
        plt.title("warmup_end_value == lr or warmup_end_value is None")
        plt.scatter(range(len(warm_lrs[:warm_steps])), warm_lrs[:warm_steps])
        plt.scatter(range(warm_steps, len(warm_lrs[warm_steps:]) + warm_steps), warm_lrs[warm_steps:])
        plt.show()


plot(None)
plot(0.26)
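In the resulting plots, when warmup_end_value is None (or equal to lr) the warmup should blend directly into the cosine curve, while an explicit warmup_end_value != lr produces a visible jump from the warmup end down to the cosine start.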
[Plots: learning-rate curves for "warmup_end_value == lr or warmup_end_value is None" and for "warmup_end_value != lr"]

github-actions bot added the "module: handlers" (Core Handlers module) label on May 8, 2023
vfdev-5 changed the title from "Bug cosine scheduler" to "Fixed parameter scheduler bug with CosineAnnealingWarmRestarts" on May 8, 2023
@vfdev-5 (Collaborator) left a comment
Thanks for the updates @AlexanderChaptykov
I left a few suggestions on how to improve the PR.

On this excerpt of the proposed test:

    assert warm_lrs[warm_steps:] == cosine_lrs
else:
    assert (np.linspace(warm_start, lr, warm_steps).round(3) == np.array(warm_lrs[:warm_steps]).round(3)).all()
    assert warm_lrs[warm_steps - 1 : -1] == cosine_lrs
@AlexanderChaptykov (Contributor, Author) replied:

We need this because of the shift in the recorded LRs when warmup_end_value is None.
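A toy illustration of that shift, with made-up LR values (not taken from the PR): when warmup_end_value is None, the warmup ends exactly at the cosine scheduler's first LR, so the cosine curve starts one step earlier inside warm_lrs.

# Made-up numbers; only the indexing logic matters here
warm_steps = 3
cosine_lrs = [0.2, 0.15, 0.1, 0.05]                    # standalone cosine schedule
warm_lrs = [0.023, 0.11, 0.2, 0.15, 0.1, 0.05, 0.02]   # warmup + cosine, one extra step
assert warm_lrs[warm_steps - 1 : -1] == cosine_lrs     # shifted comparison holds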

@vfdev-5 (Collaborator) commented May 23, 2023

Let's make the test as follows:

import numpy as np
import pytest
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers import create_lr_scheduler_with_warmup


@pytest.mark.parametrize("warmup_end_value", [0.23, None])
@pytest.mark.parametrize("T_0", [1, 12])
@pytest.mark.parametrize("T_mult", [1, 3])
def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value, T_0, T_mult):
    lr = 0.2
    steps = 200
    warm_steps = 50
    warm_start = 0.023

    def get_optim():
        t1 = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([t1], lr=lr)

    def get_cos_shed():
        return CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, verbose=False)

    optimizer = get_optim()
    scheduler = get_cos_shed()
    cosine_lrs = []
    for i in range(steps):
        cosine_lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    optimizer = get_optim()
    scheduler = create_lr_scheduler_with_warmup(
        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
    )

    warm_lrs = []
    for epoch in range(warm_steps + steps):
        scheduler(None)
        warm_lrs.append(optimizer.param_groups[0]["lr"])

    if warmup_end_value is not None:
        np.testing.assert_allclose(np.linspace(warm_start, warmup_end_value, warm_steps), warm_lrs[:warm_steps])
        assert warm_lrs[warm_steps:] == cosine_lrs
    else:
        np.testing.assert_allclose(np.linspace(warm_start, lr, warm_steps), warm_lrs[:warm_steps])
        assert warm_lrs[warm_steps - 1:-1] == cosine_lrs
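The parametrization expands to 2 × 2 × 2 = 8 cases, covering the explicit warmup_end_value and the None branch against both a short (T_0=1) and a longer (T_0=12) first cosine cycle, with and without cycle-length growth (T_mult).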

@vfdev-5 (Collaborator) commented May 23, 2023

An updated plotting script that compares the combined schedule with the standalone cosine schedule side by side:

import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers import create_lr_scheduler_with_warmup


def plot(warmup_end_value):
    lr = 0.2
    warm_steps = 5
    steps = 100
    warm_start = 0.023

    def get_optim():
        t1 = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([t1], lr=lr)

    def get_cos_shed():
        return CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, verbose=False)

    optimizer = get_optim()
    scheduler = get_cos_shed()
    cosine_lrs = []
    for i in range(steps):
        cosine_lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    optimizer = get_optim()
    scheduler = create_lr_scheduler_with_warmup(
        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
    )

    warm_lrs = []
    for epoch in range(warm_steps + steps):
        scheduler(None)
        warm_lrs.append(optimizer.param_groups[0]["lr"])

    if warmup_end_value is not None:
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.title("create_lr_scheduler_with_warmup +\nCosineAnnealingWarmRestarts\nwarmup_end_value != lr")
        plt.plot(warm_lrs, "-*")
        plt.subplot(122)
        plt.title("CosineAnnealingWarmRestarts")
        plt.plot(cosine_lrs, "-*")        
        plt.show()
    else:
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.title("create_lr_scheduler_with_warmup +\nCosineAnnealingWarmRestarts\nwarmup_end_value == lr")
        plt.plot(warm_lrs, "-*")
        plt.subplot(122)
        plt.title("CosineAnnealingWarmRestarts")
        plt.plot(cosine_lrs, "-*")        
        plt.show()


plot(None)
plot(0.26)

[Plots: combined warmup + CosineAnnealingWarmRestarts schedule (left) vs. standalone CosineAnnealingWarmRestarts (right), for warmup_end_value is None and warmup_end_value != lr]

@vfdev-5 (Collaborator) left a comment
LGTM, thanks @AlexanderChaptykov for working on this issue!


Labels: module: handlers (Core Handlers module)

Successfully merging this pull request may close these issues:

WarmRestarts seems not working with create_lr_scheduler_with_warmup function