Skip to content

Commit c4eaf9f

Browse files
committed
Added consistency projection, addressed comments for the notebook
Signed-off-by: Ante Jukić <[email protected]>
1 parent 5080207 commit c4eaf9f

File tree

6 files changed

+1393
-1161
lines changed

6 files changed

+1393
-1161
lines changed

nemo/collections/asr/models/enhancement_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
6060
self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor)
6161
self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)
6262

63+
if 'mixture_consistency' in self._cfg:
64+
self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
65+
else:
66+
self.mixture_consistency = None
67+
6368
# Future enhancement:
6469
# If subclasses need to modify the config before calling super()
6570
# Check ASRBPE* classes do with their mixin
@@ -370,6 +375,10 @@ def forward(self, input_signal, input_length=None):
370375
# Mask-based processor in the encoded domain
371376
processed, processed_length = self.mask_processor(input=encoded, input_length=encoded_length, mask=mask)
372377

378+
# Mixture consistency
379+
if self.mixture_consistency is not None:
380+
processed = self.mixture_consistency(mixture=encoded, estimate=processed)
381+
373382
# Decoder
374383
processed, processed_length = self.decoder(input=processed, input_length=processed_length)
375384

nemo/collections/asr/modules/audio_modules.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,3 +878,72 @@ def forward(
878878
output, output_length = self.filter(input=output, input_length=input_length, power=power)
879879

880880
return output.to(io_dtype), output_length
881+
882+
883+
class MixtureConsistencyProjection(NeuralModule):
    """Project estimated sources so that they are consistent with the mixture.

    The residual between the input mixture and the sum of the estimated
    sources is added back to every estimate, scaled by a per-source weight.
    Note that the input mixture is assumed to be a single-channel signal.

    Args:
        weighting: Optional weighting mode for the consistency constraint.
            If `None`, every source receives the same share `1 / M` of the
            residual. If `power`, the share of each source is its power
            divided by the regularized total power of all sources.
        eps: Small positive value added to the total power for regularization

    Reference:
        Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, 2018
    """

    def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
        super().__init__()
        self.weighting = weighting
        self.eps = eps

        # Reject unsupported modes at construction time.
        supported_modes = (None, 'power')
        if self.weighting not in supported_modes:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        axes = ('B', 'C', 'D', 'T')
        # Each port gets its own NeuralType instance.
        return {port: NeuralType(axes, SpectrogramType()) for port in ("mixture", "estimate")}

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        axes = ('B', 'C', 'D', 'T')
        return {"output": NeuralType(axes, SpectrogramType())}

    @typecheck()
    def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
        """Enforce mixture consistency on the estimated sources.

        Args:
            mixture: Single-channel mixture, shape (B, 1, F, N)
            estimate: M estimated sources, shape (B, M, F, N)

        Returns:
            Source estimates consistent with the mixture, shape (B, M, F, N)
        """
        # Sources are indexed along the third dimension from the end.
        source_dim = -3
        num_sources = estimate.size(source_dim)

        # Mixture implied by the current estimates; kept 4-D so it broadcasts.
        summed_estimate = estimate.sum(dim=source_dim, keepdim=True)

        # Share of the residual assigned to each source.
        if self.weighting is None:
            share = 1 / num_sources
        elif self.weighting == 'power':
            source_power = estimate.abs().pow(2)
            total_power = source_power.sum(dim=source_dim, keepdim=True)
            share = source_power / (total_power + self.eps)
        else:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

        # Distribute the mismatch between the mixture and the summed estimates.
        return estimate + share * (mixture - summed_estimate)

0 commit comments

Comments
 (0)