
Commit 40f2a08

[BC-Breaking] Move fine-tune specific module out of wav2vec2 encoder (#1782)
Previously, the Linear module (called `readout`, used only for an ASR fine-tuning task) was placed inside the encoder module. Conceptually, the encoder has nothing to do with a module specific to fine-tuning / a downstream task. The problems are:

1. The encoder can also be used in the pre-training phase, where such a module should not be present.
2. The choice of a Linear module is arbitrary, and a hard-coded module structure in the encoder is inconvenient for users.

Therefore, this commit moves the Linear module out of the encoder and attaches it to `Wav2Vec2Model` as the `aux` attribute. (As a result, `Wav2Vec2Model` has `feature_extractor`, `encoder`, and `aux` attributes.) An alternative approach would be to define another module that holds `Wav2Vec2Model` and the aux module side by side, but that would introduce a new class to maintain. The expected uses of `aux` are only:

1. loading the pre-trained parameters published by `fairseq` (and their variants from Hugging Face), and
2. creating the same model architectures for comparison experiments.

Such a new class would not be general enough for downstream adaptations, where there will be many different, more complicated models (e.g. s3prl). Therefore, following a minimalistic approach, we put the auxiliary module inside `Wav2Vec2Model`.
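To make the new layout concrete, here is a minimal sketch of the resulting structure, condensed from the model.py diff below (signatures abridged; not the verbatim source):

    from typing import Optional

    from torch.nn import Module

    class Wav2Vec2Model(Module):
        # Sketch: `aux` is optional and, when present, is applied to the encoder output.
        def __init__(self, feature_extractor: Module, encoder: Module, aux: Optional[Module] = None):
            super().__init__()
            self.feature_extractor = feature_extractor
            self.encoder = encoder
            self.aux = aux  # e.g. torch.nn.Linear for ASR fine-tuning; None for pre-training

        def forward(self, waveforms, lengths=None):
            x, lengths = self.feature_extractor(waveforms, lengths)
            x = self.encoder(x, lengths)
            if self.aux is not None:  # pre-training models skip the fine-tuning head
                x = self.aux(x)
            return x, lengths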
1 parent e9cab8f commit 40f2a08

File tree

5 files changed: +28 −25 lines changed

test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py

Lines changed: 5 additions & 5 deletions
@@ -118,7 +118,7 @@ def _test_import_finetune(self, original, imported, config):
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
         ref = original.lm_head(x)
-        hyp = imported.encoder.readout(x)
+        hyp = imported.aux(x)
         self.assertEqual(ref, hyp)
         # The whole model without mask
         x = torch.randn(3, 1024)
@@ -195,8 +195,8 @@ def _test_recreate(self, imported, reloaded, config):
         self.assertEqual(ref, hyp)
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
-        ref = imported.encoder.readout(x)
-        hyp = reloaded.encoder.readout(x)
+        ref = imported.aux(x)
+        hyp = reloaded.aux(x)
         self.assertEqual(ref, hyp)
         # The whole model
         x = torch.randn(3, 1024)
@@ -208,7 +208,7 @@ def _test_recreate(self, imported, reloaded, config):
     def test_recreate_pretrain(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded = factory_func(num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
@@ -217,7 +217,7 @@ def test_recreate_pretrain(self, config, factory_func):
     def test_recreate_finetune(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.encoder.readout.out_features)
+        reloaded = factory_func(num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
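Since this change is BC-breaking, downstream code that reached into the old attribute needs the same one-line update the tests above received. A before/after sketch (model stands for any imported Wav2Vec2Model):

    # Before this commit:
    #   logits = model.encoder.readout(features)
    #   num_out = model.encoder.readout.out_features
    # After this commit:
    logits = model.aux(features)        # aux is None for pre-training-only models
    num_out = model.aux.out_features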

torchaudio/models/wav2vec2/components.py

Lines changed: 1 addition & 9 deletions
@@ -426,12 +426,10 @@ def __init__(
             self,
             feature_projection: Module,
             transformer: Module,
-            readout: Module,
     ):
         super().__init__()
         self.feature_projection = feature_projection
         self.transformer = transformer
-        self.readout = readout

     def _preprocess(
             self,
@@ -458,7 +456,6 @@ def forward(
     ) -> Tensor:
         x, mask = self._preprocess(features, lengths)
         x = self.transformer(x, attention_mask=mask)
-        x = self.readout(x)
         return x

     def extract_features(
@@ -561,7 +558,6 @@ def _get_encoder(
         dropout: float,
         layer_norm_first: bool,
         layer_drop: float,
-        num_out: int,
 ) -> Encoder:
     """
     Args:
@@ -720,8 +716,4 @@ def _get_encoder(
         layer_norm_first=not layer_norm_first,
         layer_drop=layer_drop,
     )
-    readout = nn.Linear(
-        in_features=embed_dim,
-        out_features=num_out,
-    )
-    return Encoder(feature_projection, transformer, readout)
+    return Encoder(feature_projection, transformer)

torchaudio/models/wav2vec2/model.py

Lines changed: 18 additions & 7 deletions
@@ -20,15 +20,20 @@ class Wav2Vec2Model(Module):
         encoder (torch.nn.Module):
             Encoder that converts the audio features into the sequence of probability
             distribution (in negative log-likelihood) over labels.
+
+        aux (torch.nn.Module or None, optional):
+            Auxiliary module. If provided, the output from encoder is passed to this module.
     """
     def __init__(
             self,
             feature_extractor: Module,
             encoder: Module,
+            aux: Optional[Module] = None,
     ):
         super().__init__()
         self.feature_extractor = feature_extractor
         self.encoder = encoder
+        self.aux = aux

     @torch.jit.export
     def extract_features(
@@ -89,7 +94,10 @@ def forward(
                 Shape: ``(batch, )``.
         """
         x, lengths = self.feature_extractor(waveforms, lengths)
-        return self.encoder(x, lengths), lengths
+        x = self.encoder(x, lengths)
+        if self.aux is not None:
+            x = self.aux(x)
+        return x, lengths


 def _get_model(
@@ -108,7 +116,7 @@ def _get_model(
         encoder_dropout: float,
         encoder_layer_norm_first: bool,
         encoder_layer_drop: float,
-        encoder_num_out: int,
+        aux_num_out: int,
 ) -> Wav2Vec2Model:
     if extractor_conv_layer_config is None:
         extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
@@ -129,9 +137,12 @@ def _get_model(
         dropout=encoder_dropout,
         layer_norm_first=encoder_layer_norm_first,
         layer_drop=encoder_layer_drop,
-        num_out=encoder_num_out,
     )
-    return Wav2Vec2Model(feature_extractor, encoder)
+    aux = torch.nn.Linear(
+        in_features=encoder_embed_dim,
+        out_features=aux_num_out,
+    )
+    return Wav2Vec2Model(feature_extractor, encoder, aux)


 def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
@@ -172,7 +183,7 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.1,
         encoder_layer_norm_first=False,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )


@@ -214,7 +225,7 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.1,
         encoder_layer_norm_first=False,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )


@@ -256,5 +267,5 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
         encoder_dropout=0.0,
         encoder_layer_norm_first=True,
         encoder_layer_drop=0.1,
-        encoder_num_out=num_out,
+        aux_num_out=num_out,
     )
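The factory functions keep their public num_out parameter, so callers are unaffected; only the internal wiring changed. A usage sketch, assuming these factories are importable from torchaudio.models and that lengths may be omitted (the exact forward signature is abridged in the diff above):

    import torch
    from torchaudio.models import wav2vec2_base

    model = wav2vec2_base(num_out=32).eval()  # 32 is an arbitrary vocabulary size
    waveforms = torch.randn(3, 1024)          # (batch, time), matching the test inputs above
    with torch.no_grad():
        logits, lengths = model(waveforms)
    # logits now come from model.aux, the Linear layer built in _get_model.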

torchaudio/models/wav2vec2/utils/import_fairseq.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def _parse_config(w2v_model, num_out):
         'encoder_dropout': encoder.layers[0].dropout3.p,
         'encoder_layer_norm_first': encoder.layer_norm_first,
         'encoder_layer_drop': encoder.layerdrop,
-        'encoder_num_out': num_out,
+        'aux_num_out': num_out,
     }
     return config

@@ -110,7 +110,7 @@ def _map_key(key):
     match = re.match(r"proj\.(weight|bias)", key)
     # Encoder - Readout layer
     if match:
-        return f"encoder.readout.{match.group(1)}"
+        return f"aux.{match.group(1)}"
     raise ValueError(f'Unexpected key: {key_}')
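The renamed key mapping means a fairseq checkpoint's proj.weight / proj.bias parameters now land on aux. A tiny self-contained illustration of that one branch (simplified from _map_key above; the helper name is hypothetical):

    import re

    def map_readout_key(key):
        # Mirrors the branch shown above: fairseq's final projection maps to `aux`.
        match = re.match(r"proj\.(weight|bias)", key)
        if match:
            return f"aux.{match.group(1)}"
        raise ValueError(f"Unexpected key: {key}")

    assert map_readout_key("proj.weight") == "aux.weight"
    assert map_readout_key("proj.bias") == "aux.bias"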

torchaudio/models/wav2vec2/utils/import_huggingface.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ def _get_config(cfg):
         'encoder_dropout': cfg.hidden_dropout,
         'encoder_layer_norm_first': cfg.do_stable_layer_norm,
         'encoder_layer_drop': cfg.layerdrop,
-        'encoder_num_out': cfg.vocab_size,
+        'aux_num_out': cfg.vocab_size,
     }
     return config

@@ -42,7 +42,7 @@ def _build(config, original):
     imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
     imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
     if original.__class__.__name__ == 'Wav2Vec2ForCTC':
-        imported.encoder.readout.load_state_dict(original.lm_head.state_dict())
+        imported.aux.load_state_dict(original.lm_head.state_dict())
     return imported
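For completeness, a hedged sketch of the import path end to end; the public name import_huggingface_model is taken from the test file above, and its exact module path plus the checkpoint name are assumptions for illustration:

    from transformers import Wav2Vec2ForCTC
    from torchaudio.models.wav2vec2.utils import import_huggingface_model

    # Illustrative checkpoint; any fine-tuned Wav2Vec2ForCTC model should behave the same.
    original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    imported = import_huggingface_model(original).eval()

    # The CTC head (original.lm_head) is now loaded into imported.aux,
    # not imported.encoder.readout as before this commit.
    assert imported.aux.out_features == original.config.vocab_size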
