Skip to content

Commit 1ec3065

Browse files
committed
Added consistency projection, addressed comments for the notebook
Signed-off-by: Ante Jukić <[email protected]>
1 parent 6d95580 commit 1ec3065

File tree

3 files changed

+972
-70
lines changed

3 files changed

+972
-70
lines changed

nemo/collections/asr/models/enhancement_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
6060
self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor)
6161
self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)
6262

63+
if 'mixture_consistency' in self._cfg:
64+
self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
65+
else:
66+
self.mixture_consistency = None
67+
6368
# Future enhancement:
6469
# If subclasses need to modify the config before calling super()
6570
# Check ASRBPE* classes do with their mixin
@@ -370,6 +375,10 @@ def forward(self, input_signal, input_length=None):
370375
# Mask-based processor in the encoded domain
371376
processed, processed_length = self.mask_processor(input=encoded, input_length=encoded_length, mask=mask)
372377

378+
# Mixture consistency
379+
if self.mixture_consistency is not None:
380+
processed = self.mixture_consistency(mixture=encoded, estimate=processed)
381+
373382
# Decoder
374383
processed, processed_length = self.decoder(input=processed, input_length=processed_length)
375384

nemo/collections/asr/modules/audio_modules.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,3 +878,69 @@ def forward(
878878
output, output_length = self.filter(input=output, input_length=input_length, power=power)
879879

880880
return output.to(io_dtype), output_length
881+
882+
883+
class MixtureConsistencyProjection(NeuralModule):
884+
"""Ensure estimated sources are consistent with the input mixture.
885+
Note that the input mixture is assume to be a single-channel signal.
886+
887+
Args:
888+
weighting: Optional weighting mode for the consistency constraint.
889+
If `None`, use uniform weighting. If `power`, use the power of the
890+
estimated source as the weight.
891+
eps: Small positive value for regularization
892+
893+
Reference:
894+
Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, 2018
895+
"""
896+
897+
def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
898+
super().__init__()
899+
self.weighting = weighting
900+
self.eps = eps
901+
902+
@property
903+
def input_types(self) -> Dict[str, NeuralType]:
904+
"""Returns definitions of module output ports.
905+
"""
906+
return {
907+
"mixture": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
908+
"estimate": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
909+
}
910+
911+
@property
912+
def output_types(self) -> Dict[str, NeuralType]:
913+
"""Returns definitions of module output ports.
914+
"""
915+
return {
916+
"output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
917+
}
918+
919+
@typecheck()
920+
def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
921+
"""Enforce mixture consistency on the estimated sources.
922+
Args:
923+
mixture: Single-channel mixture, shape (B, 1, F, N)
924+
estimate: M estimated sources, shape (B, M, F, N)
925+
926+
Returns:
927+
Source estimates consistent with the mixture, shape (B, M, F, N)
928+
"""
929+
# number of sources
930+
M = estimate.size(-3)
931+
# estimated mixture based on the estimated sources
932+
estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True)
933+
934+
# weighting
935+
if self.weighting == None:
936+
weight = 1 / M
937+
elif self.weighting == 'power':
938+
weight = estimate.abs().pow(2)
939+
weight = weight / (weight.sum(dim=-3, keepdim=True) + self.eps)
940+
else:
941+
raise NotImplementedError(f'Weighting mode {self.weighting_mode} not implemented')
942+
943+
# consistent estimate
944+
consistent_estimate = estimate + weight * (mixture - estimated_mixture)
945+
946+
return consistent_estimate

0 commit comments

Comments
 (0)