@@ -52,6 +52,7 @@ def main(args):
         config_allow_defaults=True,
         epoch=args.epoch,
         mask_only=args.no_df_stage,
+        device=args.device,
     )
     suffix = suffix if args.suffix else None
     if args.output_dir is None:
@@ -76,7 +77,12 @@ def main(args):
         progress = (i + 1) / n_samples * 100
         t0 = time.time()
         audio = enhance(
-            model, df_state, audio, pad=args.compensate_delay, atten_lim_db=args.atten_lim
+            model,
+            df_state,
+            audio,
+            pad=args.compensate_delay,
+            atten_lim_db=args.atten_lim,
+            device=args.device,
         )
         t1 = time.time()
         t_audio = audio.shape[-1] / df_sr
@@ -107,6 +113,7 @@ def init_df(
     epoch: Union[str, int, None] = "best",
     default_model: str = DEFAULT_MODEL,
     mask_only: bool = False,
+    device: Optional[str] = None,
 ) -> Tuple[nn.Module, DF, str, int]:
     """Initializes and loads config, model and deep filtering state.
 
@@ -119,6 +126,8 @@ def init_df(
         config_allow_defaults (bool): Whether to allow initializing new config values with defaults.
         epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, `<int>`, and `none`.
             `none` disables checkpoint loading. Defaults to `best`.
+        device (str): Set the torch compute device.
+            If None, will automatically choose an available backend. (Optional)
 
     Returns:
         model (nn.Modules): Intialized model, moved to GPU if available.
@@ -177,17 +186,19 @@ def init_df(
         logger.error("Could not find a checkpoint")
         exit(1)
     logger.debug(f"Loaded checkpoint from epoch {epoch}")
-    model = model.to(get_device())
+
+    compute_device = get_device(device=device)
+    model = model.to(compute_device)
     # Set suffix to model name
     suffix = os.path.basename(os.path.abspath(model_base_dir))
     if post_filter:
         suffix += "_pf"
-    logger.info("Running on device {}".format(get_device()))
+    logger.info("Running on device {}".format(compute_device))
     logger.info("Model loaded")
     return model, df_state, suffix, epoch
 
 
-def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor, Tensor, Tensor]:
+def df_features(audio: Tensor, df: DF, nb_df: int, device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor, Tensor]:
     spec = df.analysis(audio.numpy())  # [C, Tf] -> [C, Tf, F]
     a = get_norm_alpha(False)
     erb_fb = df.erb_widths()
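
The hunk above routes device selection through get_device(device=device); the corresponding change to get_device itself is not part of this file and is not shown here. A minimal sketch of the resolver this assumes, with the fallback order being an assumption, could look like:

# Hypothetical sketch only; the real get_device is defined elsewhere in the
# codebase and its fallback logic may differ.
from typing import Optional

import torch


def get_device(device: Optional[str] = None) -> torch.device:
    if device is not None:
        # Honor an explicit request such as "cpu", "cuda:1", or "mps".
        return torch.device(device)
    # Otherwise pick the first available backend (assumed order: CUDA, then CPU).
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
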
@@ -205,7 +216,12 @@ def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor,
 
 @torch.no_grad()
 def enhance(
-    model: nn.Module, df_state: DF, audio: Tensor, pad=True, atten_lim_db: Optional[float] = None
+    model: nn.Module,
+    df_state: DF,
+    audio: Tensor,
+    pad=True,
+    atten_lim_db: Optional[float] = None,
+    device: Optional[str] = None,
 ):
     """Enhance a single audio given a preloaded model and DF state.
 
@@ -216,23 +232,30 @@ def enhance(
         pad (bool): Pad the audio to compensate for delay due to STFT/ISTFT.
         atten_lim_db (float): An optional noise attenuation limit in dB. E.g. an attenuation limit of
             12 dB only suppresses 12 dB and keeps the remaining noise in the resulting audio.
+        device (str): Set the torch compute device.
+            If None, will automatically choose an available backend. (Optional)
 
     Returns:
         enhanced audio (Tensor): If `pad` was `False` of shape [C, T'] where T'<T slightly delayed due to STFT.
             If `pad` was `True` it has the same shape as the input.
     """
+    compute_device = get_device(device=device)
+    model.to(compute_device)
     model.eval()
+
     bs = audio.shape[0]
     if hasattr(model, "reset_h0"):
-        model.reset_h0(batch_size=bs, device=get_device())
+        model.reset_h0(batch_size=bs, device=compute_device)
     orig_len = audio.shape[-1]
     n_fft, hop = 0, 0
     if pad:
         n_fft, hop = df_state.fft_size(), df_state.hop_size()
         # Pad audio to compensate for the delay due to the real-time STFT implementation
         audio = F.pad(audio, (0, n_fft))
     nb_df = getattr(model, "nb_df", getattr(model, "df_bins", ModelParams().nb_df))
-    spec, erb_feat, spec_feat = df_features(audio, df_state, nb_df, device=get_device())
+    spec, erb_feat, spec_feat = df_features(
+        audio, df_state, nb_df, device=compute_device
+    )
     enhanced = model(spec.clone(), erb_feat, spec_feat)[0].cpu()
     enhanced = as_complex(enhanced.squeeze(1))
     if atten_lim_db is not None and abs(atten_lim_db) > 0:
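
With the new keyword arguments in place, callers of the Python API can pin the compute device explicitly. A small usage sketch, assuming init_df can still be called without a model directory (falling back to the default model) and that the existing load_audio/save_audio helpers in df.enhance keep their current signatures; file names are placeholders:

from df.enhance import enhance, init_df, load_audio, save_audio

# Force CPU inference; pass e.g. "cuda:0" to target a specific GPU instead.
model, df_state, _, _ = init_df(device="cpu")
audio, _ = load_audio("noisy.wav", sr=df_state.sr())  # placeholder input file
enhanced = enhance(model, df_state, audio, device="cpu")
save_audio("enhanced.wav", enhanced, df_state.sr())  # placeholder output file
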
@@ -375,6 +398,11 @@ def run():
         help="Don't add the model suffix to the enhanced audio files",
     )
     parser.add_argument("--no-df-stage", action="store_true")
+    parser.add_argument(
+        "--device",
+        type=str,
+        help="Set the torch compute device",
+    )
     args = parser.parse_args()
     main(args)
 
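
On the command line, the new flag is forwarded to both init_df and enhance through args.device. An illustrative invocation (the entry-point name and the noisy input file are placeholders):

python -m df.enhance --device cpu noisy.wav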