@@ -878,3 +878,72 @@ def forward(
878
878
output , output_length = self .filter (input = output , input_length = input_length , power = power )
879
879
880
880
return output .to (io_dtype ), output_length
881
+
882
+
883
class MixtureConsistencyProjection(NeuralModule):
    """Project estimated sources so that they are consistent with the mixture.

    The difference between the input mixture and the sum of the estimated
    sources is distributed back onto the individual estimates according to
    a per-source weight.

    Note that the input mixture is assumed to be a single-channel signal.

    Args:
        weighting: Optional weighting mode for the consistency constraint.
            If `None`, use uniform weighting. If `power`, use the power of the
            estimated source as the weight.
        eps: Small positive value for regularization

    Raises:
        NotImplementedError: If `weighting` is neither `None` nor `'power'`.

    Reference:
        Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, 2018
    """

    def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
        super().__init__()
        self.weighting = weighting
        self.eps = eps

        # Reject unsupported modes at construction time rather than on first use.
        if self.weighting not in (None, 'power'):
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports.
        """
        spectrogram_axes = ('B', 'C', 'D', 'T')
        return {
            "mixture": NeuralType(spectrogram_axes, SpectrogramType()),
            "estimate": NeuralType(spectrogram_axes, SpectrogramType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports.
        """
        return {"output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType())}

    @typecheck()
    def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
        """Enforce mixture consistency on the estimated sources.

        Args:
            mixture: Single-channel mixture, shape (B, 1, F, N)
            estimate: M estimated sources, shape (B, M, F, N)

        Returns:
            Source estimates consistent with the mixture, shape (B, M, F, N)

        Raises:
            NotImplementedError: If `self.weighting` holds an unsupported mode.
        """
        # The sources are stacked along the third-to-last dimension.
        num_sources = estimate.size(-3)
        # Mixture implied by the current estimates (sum over sources).
        mixture_from_estimate = estimate.sum(dim=-3, keepdim=True)

        if self.weighting is None:
            # Uniform weighting: every source receives an equal share.
            weight = 1 / num_sources
        elif self.weighting == 'power':
            # Share proportional to the power of each estimated source;
            # eps keeps the normalization finite when the total power is zero.
            source_power = estimate.abs().pow(2)
            weight = source_power / (source_power.sum(dim=-3, keepdim=True) + self.eps)
        else:
            # Reachable only if the attribute was modified after construction.
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

        # Part of the mixture not explained by the sum of the estimates.
        residual = mixture - mixture_from_estimate

        return estimate + weight * residual
0 commit comments