Skip to content

Commit dbc1621

Browse files
authored
A quickstart speech enhancement tutorial (NVIDIA-NeMo#6492)
A simple example of training a model for speech enhancement task Signed-off-by: Ante Jukić <[email protected]>
1 parent 8f815ff commit dbc1621

File tree

11 files changed

+1576
-29
lines changed

11 files changed

+1576
-29
lines changed

examples/asr/audio_to_audio/conf/multichannel_enhancement.yaml renamed to examples/audio_tasks/conf/beamforming.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# This configuration contains the default values for training a multichannel speech enhancement model.
1+
# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer.
22
#
3-
name: "multichannel_enhancement"
3+
name: "beamforming"
44

55
model:
66
sample_rate: 16000
@@ -78,10 +78,10 @@ model:
7878

7979
optim:
8080
name: adamw
81-
lr: 1e-3
81+
lr: 1e-4
8282
# optimizer arguments
8383
betas: [0.9, 0.98]
84-
weight_decay: 0
84+
weight_decay: 1e-3
8585

8686
trainer:
8787
devices: -1 # number of GPUs, -1 would use all available GPUs
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer.
2+
#
3+
name: "masking"
4+
5+
model:
6+
sample_rate: 16000
7+
skip_nan_grad: false
8+
num_outputs: 1
9+
10+
train_ds:
11+
manifest_filepath: ???
12+
input_key: audio_filepath # key of the input signal path in the manifest
13+
target_key: target_filepath # key of the target signal path in the manifest
14+
target_channel_selector: 0 # target signal is the first channel from files in target_key
15+
audio_duration: 4.0 # in seconds, audio segment duration for training
16+
random_offset: true # if the file is longer than audio_duration, use random offset to select a subsegment
17+
min_duration: ${model.train_ds.audio_duration}
18+
batch_size: 64 # batch size may be increased based on the available memory
19+
shuffle: true
20+
num_workers: 8
21+
pin_memory: true
22+
23+
validation_ds:
24+
manifest_filepath: ???
25+
input_key: audio_filepath # key of the input signal path in the manifest
26+
target_key: target_filepath
27+
target_channel_selector: 0 # target signal is the first channel from files in target_key
28+
batch_size: 64 # batch size may be increased based on the available memory
29+
shuffle: false
30+
num_workers: 4
31+
pin_memory: true
32+
33+
test_ds:
34+
manifest_filepath: ???
35+
input_key: audio_filepath # key of the input signal path in the manifest
36+
target_key: target_filepath # key of the target signal path in the manifest
37+
target_channel_selector: 0 # target signal is the first channel from files in target_key
38+
batch_size: 1 # batch size may be increased based on the available memory
39+
shuffle: false
40+
num_workers: 4
41+
pin_memory: true
42+
43+
encoder:
44+
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
45+
fft_length: 512 # Length of the window and FFT for calculating spectrogram
46+
hop_length: 256 # Hop length for calculating spectrogram
47+
power: null
48+
49+
decoder:
50+
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
51+
fft_length: 512 # Length of the window and FFT for calculating spectrogram
52+
hop_length: 256 # Hop length for calculating spectrogram
53+
54+
mask_estimator:
55+
_target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorRNN
56+
num_outputs: ${model.num_outputs}
57+
num_subbands: 257 # Number of subbands of the input spectrogram
58+
num_features: 256 # Number of features at RNN input
59+
num_layers: 5 # Number of RNN layers
60+
bidirectional: true # Use bi-directional RNN
61+
62+
mask_processor:
63+
_target_: nemo.collections.asr.modules.audio_modules.MaskReferenceChannel # Apply mask on the reference channel
64+
ref_channel: 0 # Reference channel for the output
65+
66+
loss:
67+
_target_: nemo.collections.asr.losses.SDRLoss
68+
scale_invariant: true # Use scale-invariant SDR
69+
70+
metrics:
71+
val:
72+
sdr: # output SDR
73+
_target_: torchmetrics.audio.SignalDistortionRatio
74+
test:
75+
sdr_ch0: # SDR on output channel 0
76+
_target_: torchmetrics.audio.SignalDistortionRatio
77+
channel: 0
78+
79+
optim:
80+
name: adamw
81+
lr: 1e-4
82+
# optimizer arguments
83+
betas: [0.9, 0.98]
84+
weight_decay: 1e-3
85+
86+
trainer:
87+
devices: -1 # number of GPUs, -1 would use all available GPUs
88+
num_nodes: 1
89+
max_epochs: -1
90+
max_steps: -1 # computed at runtime if not set
91+
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
92+
accelerator: auto
93+
strategy: ddp
94+
accumulate_grad_batches: 1
95+
gradient_clip_val: null
96+
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
97+
log_every_n_steps: 25 # Interval of logging.
98+
enable_progress_bar: true
99+
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
100+
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
101+
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
102+
sync_batchnorm: true
103+
enable_checkpointing: False # Provided by exp_manager
104+
logger: false # Provided by exp_manager
105+
106+
exp_manager:
107+
exp_dir: null
108+
name: ${name}
109+
create_tensorboard_logger: true
110+
create_checkpoint_callback: true
111+
checkpoint_callback_params:
112+
# in case of multiple validation sets, first one is used
113+
monitor: "val_loss"
114+
mode: "min"
115+
save_top_k: 5
116+
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints
117+
118+
# you need to set these two to true to continue the training
119+
resume_if_exists: false
120+
resume_ignore_no_checkpoint: false
121+
122+
# You may use this section to create a W&B logger
123+
create_wandb_logger: false
124+
wandb_logger_kwargs:
125+
name: null
126+
project: null

examples/asr/audio_to_audio/speech_enhancement.py renamed to examples/audio_tasks/speech_enhancement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# Training the model
1717
1818
Basic run (on CPU for 50 epochs):
19-
python examples/asr/experimental/audio_to_audio/speech_enhancement.py \
19+
python examples/audio_tasks/speech_enhancement.py \
2020
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
2121
model.train_ds.manifest_filepath="<path to manifest file>" \
2222
model.validation_ds.manifest_filepath="<path to manifest file>" \
@@ -36,7 +36,7 @@
3636
from nemo.utils.exp_manager import exp_manager
3737

3838

39-
@hydra_runner(config_path="./conf", config_name="multichannel_enhancement")
39+
@hydra_runner(config_path="./conf", config_name="masking")
4040
def main(cfg):
4141
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg, resolve=True)}')
4242

nemo/collections/asr/data/audio_to_audio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def get_samples_synchronized(
462462

463463
if duration + fixed_offset > min_audio_duration:
464464
# The shortest file is shorter than the requested duration
465-
logging.warning(
465+
logging.debug(
466466
f'Shortest file ({min_audio_duration}s) is less than the desired duration {duration}s + fixed offset {fixed_offset}s. Returned signals will be shortened to {available_duration} seconds.'
467467
)
468468
offset = fixed_offset

nemo/collections/asr/models/enhancement_models.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
6060
self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor)
6161
self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)
6262

63+
if 'mixture_consistency' in self._cfg:
64+
self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
65+
else:
66+
self.mixture_consistency = None
67+
6368
# Future enhancement:
6469
# If subclasses need to modify the config before calling super()
6570
# Check ASRBPE* classes do with their mixin
@@ -316,7 +321,7 @@ def input_types(self) -> Dict[str, NeuralType]:
316321
"input_signal": NeuralType(
317322
('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)
318323
), # multi-channel format, channel dimension can be 1 for single-channel audio
319-
"input_length": NeuralType(tuple('B'), LengthsType()),
324+
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
320325
}
321326

322327
@property
@@ -325,7 +330,7 @@ def output_types(self) -> Dict[str, NeuralType]:
325330
"output_signal": NeuralType(
326331
('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)
327332
), # multi-channel format, channel dimension can be 1 for single-channel audio
328-
"output_length": NeuralType(tuple('B'), LengthsType()),
333+
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
329334
}
330335

331336
def match_batch_length(self, input: torch.Tensor, batch_length: int):
@@ -346,7 +351,7 @@ def match_batch_length(self, input: torch.Tensor, batch_length: int):
346351
return torch.nn.functional.pad(input, pad, 'constant', 0)
347352

348353
@typecheck()
349-
def forward(self, input_signal, input_length):
354+
def forward(self, input_signal, input_length=None):
350355
"""
351356
Forward pass of the model.
352357
@@ -370,6 +375,10 @@ def forward(self, input_signal, input_length):
370375
# Mask-based processor in the encoded domain
371376
processed, processed_length = self.mask_processor(input=encoded, input_length=encoded_length, mask=mask)
372377

378+
# Mixture consistency
379+
if self.mixture_consistency is not None:
380+
processed = self.mixture_consistency(mixture=encoded, estimate=processed)
381+
373382
# Decoder
374383
processed, processed_length = self.decoder(input=processed, input_length=processed_length)
375384

nemo/collections/asr/modules/audio_modules.py

Lines changed: 88 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,14 @@ def __init__(
251251
else:
252252
raise ValueError(f'Unknown rnn_type: {rnn_type}')
253253

254+
self.fc = torch.nn.Linear(
255+
in_features=2 * num_features if bidirectional else num_features, out_features=num_features
256+
)
257+
self.norm = torch.nn.LayerNorm(num_features)
258+
254259
# Each output shares the RNN and has a separate projection
255260
self.output_projections = torch.nn.ModuleList(
256-
[
257-
torch.nn.Linear(
258-
in_features=2 * num_features if bidirectional else num_features, out_features=num_subbands
259-
)
260-
for _ in range(num_outputs)
261-
]
261+
[torch.nn.Linear(in_features=num_features, out_features=num_subbands) for _ in range(num_outputs)]
262262
)
263263
self.output_nonlinearity = torch.nn.Sigmoid()
264264

@@ -310,33 +310,36 @@ def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torc
310310
).to(input.device)
311311
self.rnn.flatten_parameters()
312312
input_packed, _ = self.rnn(input_packed)
313-
input, input_length = torch.nn.utils.rnn.pad_packed_sequence(input_packed, batch_first=True)
314-
input_length = input_length.to(input.device)
313+
output, output_length = torch.nn.utils.rnn.pad_packed_sequence(input_packed, batch_first=True)
314+
output_length = output_length.to(input.device)
315+
316+
# Layer normalization and skip connection
317+
output = self.norm(self.fc(output)) + input
315318

316319
# Create `num_outputs` masks
317-
output = []
320+
masks = []
318321
for output_projection in self.output_projections:
319322
# Output projection
320-
mask = output_projection(input)
323+
mask = output_projection(output)
321324
mask = self.output_nonlinearity(mask)
322325

323326
# Back to the original format
324327
# (B, N, F) -> (B, F, N)
325328
mask = mask.transpose(2, 1)
326329

327330
# Append to the output
328-
output.append(mask)
331+
masks.append(mask)
329332

330333
# Stack along channel dimension to get (B, M, F, N)
331-
output = torch.stack(output, axis=1)
334+
masks = torch.stack(masks, axis=1)
332335

333-
# Mask frames beyond input length
336+
# Mask frames beyond output length
334337
length_mask: torch.Tensor = make_seq_mask_like(
335-
lengths=input_length, like=output, time_dim=-1, valid_ones=False
338+
lengths=output_length, like=masks, time_dim=-1, valid_ones=False
336339
)
337-
output = output.masked_fill(length_mask, 0.0)
340+
masks = masks.masked_fill(length_mask, 0.0)
338341

339-
return output, input_length
342+
return masks, output_length
340343

341344

342345
class MaskReferenceChannel(NeuralModule):
@@ -875,3 +878,72 @@ def forward(
875878
output, output_length = self.filter(input=output, input_length=input_length, power=power)
876879

877880
return output.to(io_dtype), output_length
881+
882+
883+
class MixtureConsistencyProjection(NeuralModule):
884+
"""Ensure estimated sources are consistent with the input mixture.
885+
Note that the input mixture is assume to be a single-channel signal.
886+
887+
Args:
888+
weighting: Optional weighting mode for the consistency constraint.
889+
If `None`, use uniform weighting. If `power`, use the power of the
890+
estimated source as the weight.
891+
eps: Small positive value for regularization
892+
893+
Reference:
894+
Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, 2018
895+
"""
896+
897+
def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
898+
super().__init__()
899+
self.weighting = weighting
900+
self.eps = eps
901+
902+
if self.weighting not in [None, 'power']:
903+
raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')
904+
905+
@property
906+
def input_types(self) -> Dict[str, NeuralType]:
907+
"""Returns definitions of module output ports.
908+
"""
909+
return {
910+
"mixture": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
911+
"estimate": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
912+
}
913+
914+
@property
915+
def output_types(self) -> Dict[str, NeuralType]:
916+
"""Returns definitions of module output ports.
917+
"""
918+
return {
919+
"output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
920+
}
921+
922+
@typecheck()
923+
def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
924+
"""Enforce mixture consistency on the estimated sources.
925+
Args:
926+
mixture: Single-channel mixture, shape (B, 1, F, N)
927+
estimate: M estimated sources, shape (B, M, F, N)
928+
929+
Returns:
930+
Source estimates consistent with the mixture, shape (B, M, F, N)
931+
"""
932+
# number of sources
933+
M = estimate.size(-3)
934+
# estimated mixture based on the estimated sources
935+
estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True)
936+
937+
# weighting
938+
if self.weighting is None:
939+
weight = 1 / M
940+
elif self.weighting == 'power':
941+
weight = estimate.abs().pow(2)
942+
weight = weight / (weight.sum(dim=-3, keepdim=True) + self.eps)
943+
else:
944+
raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')
945+
946+
# consistent estimate
947+
consistent_estimate = estimate + weight * (mixture - estimated_mixture)
948+
949+
return consistent_estimate

0 commit comments

Comments
 (0)