Skip to content

Commit 109d701

Browse files
committed
Added option to update parameters using state_dict in AveragedModel (pytorch#65495)
Summary: While implementing [EMA](pytorch/vision#4381) (which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used because it did not handle state_dict(), so a custom update_parameters() had to be defined in the [EMA class](pytorch/vision#4406). This PR handles that scenario, removing the need for the custom update_parameters() implementation. Discussion: pytorch/vision#4406 (review) Pull Request resolved: pytorch#65495 Reviewed By: datumbox Differential Revision: D31176742 Pulled By: prabhat00155 fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2 (cherry picked from commit 2ea724b)
1 parent 6aadfda commit 109d701

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

test/test_optim.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,6 +2290,38 @@ def avg_fn(p_avg, p, n_avg):
22902290
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
22912291
self.assertEqual(p_avg, p_swa)
22922292

2293+
def test_averaged_model_exponential_use_state_dict(self):
    """Test AveragedModel with an EMA avg_fn and mode='state_dict'.

    In state_dict mode the averaged copy tracks every state_dict entry
    (parameters AND buffers such as BatchNorm running stats), not just
    parameters().
    """
    dnn = torch.nn.Sequential(
        torch.nn.Conv2d(1, 5, kernel_size=3),
        torch.nn.BatchNorm2d(5, momentum=0.3),
        torch.nn.Linear(5, 10)
    )
    alpha = 0.9

    def avg_fn(p_avg, p, n_avg):
        # Exponential moving average update rule.
        return alpha * p_avg + (1 - alpha) * p

    averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')

    def _tracked(state_dict):
        # Exclude zero-dim entries (BatchNorm's num_batches_tracked):
        # they are integer counters, not averaged tensors we verify here.
        # Filtering BOTH sides of every zip keeps the reference list and
        # the model's state_dict aligned; filtering only one side would
        # let zip() silently drop trailing entries (the Linear layer)
        # from the comparison.
        return [p for p in state_dict.values() if p.size() != torch.Size([])]

    averaged_params = [torch.zeros_like(p) for p in _tracked(dnn.state_dict())]
    n_updates = 10
    for i in range(n_updates):
        updated_averaged_params = []
        for p, p_avg in zip(_tracked(dnn.state_dict()), averaged_params):
            # Perturb the model in place, then compute the expected EMA.
            p.detach().add_(torch.randn_like(p))
            if i == 0:
                # First update copies the model values verbatim
                # (n_averaged == 0 path in update_parameters).
                updated_averaged_params.append(p.clone())
            else:
                updated_averaged_params.append((p_avg * alpha +
                                                p * (1 - alpha)).clone())
        averaged_dnn.update_parameters(dnn)
        averaged_params = updated_averaged_params

    for p_avg, p_swa in zip(averaged_params, _tracked(averaged_dnn.module.state_dict())):
        self.assertEqual(p_avg, p_swa)
2324+
22932325
def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
22942326

22952327
preactivation_sum = torch.zeros(dnn.n_features)

torch/optim/swa_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class AveragedModel(Module):
2626
:class:`AveragedModel` parameter, the current value of :attr:`model`
2727
parameter and the number of models already averaged; if None,
2828
equally weighted average is used (default: None)
29+
mode (str, optional): whether to use parameters or state_dict for update
30+
(default: parameters)
2931
3032
Example:
3133
>>> loader, optimizer, model, loss_fn = ...
@@ -84,7 +86,7 @@ class AveragedModel(Module):
8486
Generalizes Well:
8587
https://arxiv.org/abs/2001.02312
8688
"""
87-
def __init__(self, model, device=None, avg_fn=None):
89+
def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
8890
super(AveragedModel, self).__init__()
8991
self.module = deepcopy(model)
9092
if device is not None:
@@ -96,12 +98,15 @@ def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
9698
return averaged_model_parameter + \
9799
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
98100
self.avg_fn = avg_fn
101+
self.use_state_dict = mode == 'state_dict'
99102

100103
def forward(self, *args, **kwargs):
    # Delegate the forward pass straight to the wrapped (averaged) module.
    out = self.module(*args, **kwargs)
    return out
102105

103106
def update_parameters(self, model):
104-
for p_swa, p_model in zip(self.parameters(), model.parameters()):
107+
self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
108+
model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
109+
for p_swa, p_model in zip(self_param, model_param):
105110
device = p_swa.device
106111
p_model_ = p_model.detach().to(device)
107112
if self.n_averaged == 0:

0 commit comments

Comments
 (0)