@@ -878,3 +878,69 @@ def forward(
878
878
output , output_length = self .filter (input = output , input_length = input_length , power = power )
879
879
880
880
return output .to (io_dtype ), output_length
881
+
882
+
883
+ class MixtureConsistencyProjection (NeuralModule ):
884
+ """Ensure estimated sources are consistent with the input mixture.
885
+ Note that the input mixture is assume to be a single-channel signal.
886
+
887
+ Args:
888
+ weighting: Optional weighting mode for the consistency constraint.
889
+ If `None`, use uniform weighting. If `power`, use the power of the
890
+ estimated source as the weight.
891
+ eps: Small positive value for regularization
892
+
893
+ Reference:
894
+ Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, 2018
895
+ """
896
+
897
+ def __init__ (self , weighting : Optional [str ] = None , eps : float = 1e-8 ):
898
+ super ().__init__ ()
899
+ self .weighting = weighting
900
+ self .eps = eps
901
+
902
+ @property
903
+ def input_types (self ) -> Dict [str , NeuralType ]:
904
+ """Returns definitions of module output ports.
905
+ """
906
+ return {
907
+ "mixture" : NeuralType (('B' , 'C' , 'D' , 'T' ), SpectrogramType ()),
908
+ "estimate" : NeuralType (('B' , 'C' , 'D' , 'T' ), SpectrogramType ()),
909
+ }
910
+
911
+ @property
912
+ def output_types (self ) -> Dict [str , NeuralType ]:
913
+ """Returns definitions of module output ports.
914
+ """
915
+ return {
916
+ "output" : NeuralType (('B' , 'C' , 'D' , 'T' ), SpectrogramType ()),
917
+ }
918
+
919
+ @typecheck ()
920
+ def forward (self , mixture : torch .Tensor , estimate : torch .Tensor ) -> torch .Tensor :
921
+ """Enforce mixture consistency on the estimated sources.
922
+ Args:
923
+ mixture: Single-channel mixture, shape (B, 1, F, N)
924
+ estimate: M estimated sources, shape (B, M, F, N)
925
+
926
+ Returns:
927
+ Source estimates consistent with the mixture, shape (B, M, F, N)
928
+ """
929
+ # number of sources
930
+ M = estimate .size (- 3 )
931
+ # estimated mixture based on the estimated sources
932
+ estimated_mixture = torch .sum (estimate , dim = - 3 , keepdim = True )
933
+
934
+ # weighting
935
+ if self .weighting == None :
936
+ weight = 1 / M
937
+ elif self .weighting == 'power' :
938
+ weight = estimate .abs ().pow (2 )
939
+ weight = weight / (weight .sum (dim = - 3 , keepdim = True ) + self .eps )
940
+ else :
941
+ raise NotImplementedError (f'Weighting mode { self .weighting_mode } not implemented' )
942
+
943
+ # consistent estimate
944
+ consistent_estimate = estimate + weight * (mixture - estimated_mixture )
945
+
946
+ return consistent_estimate
0 commit comments