
Commit 2e97087

Merge branch 'dev' into update_pyo3

2 parents: 1679eba + 3e08f94

File tree: 6 files changed, +53 -19 lines

DeepFilterNet/df/checkpoint.py
Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def read_cp(
     epoch = get_epoch(latest)
     if log:
         logger.info("Found checkpoint {} with epoch {}".format(latest, epoch))
-    latest = torch.load(latest, map_location="cpu")
+    latest = torch.load(latest, map_location="cpu", weights_only=True)
     latest = {k.replace("clc", "df"): v for k, v in latest.items()}
     if blacklist:
         reg = re.compile("".join(f"({b})|" for b in blacklist)[:-1])
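
Note on the change above: weights_only=True restricts torch.load to unpickling tensors and primitive containers, so a checkpoint file cannot execute arbitrary pickle code on load; recent PyTorch releases also warn when the flag is left unset. A minimal sketch of the new loading path, with a placeholder checkpoint filename and a pre-built `model`:

    import torch

    # weights_only=True limits unpickling to tensors/primitive types, so an
    # untrusted checkpoint cannot run arbitrary code while being loaded.
    state = torch.load("model_latest.ckpt", map_location="cpu", weights_only=True)
    state = {k.replace("clc", "df"): v for k, v in state.items()}  # same key rename as in read_cp
    model.load_state_dict(state)  # assumes a compatible `model` was constructed beforehand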

DeepFilterNet/df/enhance.py
Lines changed: 35 additions & 7 deletions

@@ -52,6 +52,7 @@ def main(args):
         config_allow_defaults=True,
         epoch=args.epoch,
         mask_only=args.no_df_stage,
+        device=args.device,
     )
     suffix = suffix if args.suffix else None
     if args.output_dir is None:
@@ -76,7 +77,12 @@ def main(args):
         progress = (i + 1) / n_samples * 100
         t0 = time.time()
         audio = enhance(
-            model, df_state, audio, pad=args.compensate_delay, atten_lim_db=args.atten_lim
+            model,
+            df_state,
+            audio,
+            pad=args.compensate_delay,
+            atten_lim_db=args.atten_lim,
+            device=args.device,
         )
         t1 = time.time()
         t_audio = audio.shape[-1] / df_sr
@@ -107,6 +113,7 @@ def init_df(
     epoch: Union[str, int, None] = "best",
     default_model: str = DEFAULT_MODEL,
     mask_only: bool = False,
+    device: Optional[str] = None,
 ) -> Tuple[nn.Module, DF, str, int]:
     """Initializes and loads config, model and deep filtering state.

@@ -119,6 +126,8 @@ def init_df(
         config_allow_defaults (bool): Whether to allow initializing new config values with defaults.
         epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, `<int>`, and `none`.
             `none` disables checkpoint loading. Defaults to `best`.
+        device (str): Set the torch compute device.
+            If None, will automatically choose an available backend. (Optional)

     Returns:
         model (nn.Modules): Intialized model, moved to GPU if available.
@@ -177,17 +186,19 @@ def init_df(
         logger.error("Could not find a checkpoint")
         exit(1)
     logger.debug(f"Loaded checkpoint from epoch {epoch}")
-    model = model.to(get_device())
+
+    compute_device = get_device(device=device)
+    model = model.to(compute_device)
     # Set suffix to model name
     suffix = os.path.basename(os.path.abspath(model_base_dir))
     if post_filter:
         suffix += "_pf"
-    logger.info("Running on device {}".format(get_device()))
+    logger.info("Running on device {}".format(compute_device))
     logger.info("Model loaded")
     return model, df_state, suffix, epoch


-def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor, Tensor, Tensor]:
+def df_features(audio: Tensor, df: DF, nb_df: int, device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor, Tensor]:
     spec = df.analysis(audio.numpy())  # [C, Tf] -> [C, Tf, F]
     a = get_norm_alpha(False)
     erb_fb = df.erb_widths()
@@ -205,7 +216,12 @@ def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor,

 @torch.no_grad()
 def enhance(
-    model: nn.Module, df_state: DF, audio: Tensor, pad=True, atten_lim_db: Optional[float] = None
+    model: nn.Module,
+    df_state: DF,
+    audio: Tensor,
+    pad=True,
+    atten_lim_db: Optional[float] = None,
+    device: Optional[str] = None,
 ):
     """Enhance a single audio given a preloaded model and DF state.

@@ -216,23 +232,30 @@ def enhance(
         pad (bool): Pad the audio to compensate for delay due to STFT/ISTFT.
         atten_lim_db (float): An optional noise attenuation limit in dB. E.g. an attenuation limit of
             12 dB only suppresses 12 dB and keeps the remaining noise in the resulting audio.
+        device (str): Set the torch compute device.
+            If None, will automatically choose an available backend. (Optional)

     Returns:
         enhanced audio (Tensor): If `pad` was `False` of shape [C, T'] where T'<T slightly delayed due to STFT.
             If `pad` was `True` it has the same shape as the input.
     """
+    compute_device = get_device(device=device)
+    model.to(compute_device)
     model.eval()
+
     bs = audio.shape[0]
     if hasattr(model, "reset_h0"):
-        model.reset_h0(batch_size=bs, device=get_device())
+        model.reset_h0(batch_size=bs, device=compute_device)
     orig_len = audio.shape[-1]
     n_fft, hop = 0, 0
     if pad:
         n_fft, hop = df_state.fft_size(), df_state.hop_size()
         # Pad audio to compensate for the delay due to the real-time STFT implementation
         audio = F.pad(audio, (0, n_fft))
     nb_df = getattr(model, "nb_df", getattr(model, "df_bins", ModelParams().nb_df))
-    spec, erb_feat, spec_feat = df_features(audio, df_state, nb_df, device=get_device())
+    spec, erb_feat, spec_feat = df_features(
+        audio, df_state, nb_df, device=compute_device
+    )
     enhanced = model(spec.clone(), erb_feat, spec_feat)[0].cpu()
     enhanced = as_complex(enhanced.squeeze(1))
     if atten_lim_db is not None and abs(atten_lim_db) > 0:
@@ -375,6 +398,11 @@ def run():
         help="Don't add the model suffix to the enhanced audio files",
     )
     parser.add_argument("--no-df-stage", action="store_true")
+    parser.add_argument(
+        "--device",
+        type=str,
+        help="Set the torch compute device",
+    )
     args = parser.parse_args()
     main(args)
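
The enhance.py changes thread an optional device string from the CLI and the public API down to get_device(). A hedged usage sketch of the new parameter (file paths are placeholders; function names follow the repository's existing enhance API):

    from df.enhance import enhance, init_df, load_audio, save_audio

    # Pin inference to the CPU; pass device=None (or omit it) to auto-select a backend.
    model, df_state, _suffix, _epoch = init_df(device="cpu")
    audio, _ = load_audio("noisy.wav", sr=df_state.sr())  # placeholder input file
    enhanced = enhance(model, df_state, audio, device="cpu")
    save_audio("enhanced.wav", enhanced, df_state.sr())   # placeholder output file

From the command line, the same choice is exposed through the new flag, e.g. deepFilter --device cpu noisy.wav.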

DeepFilterNet/df/utils.py
Lines changed: 11 additions & 5 deletions

@@ -17,13 +17,19 @@
 from df.model import ModelParams


-def get_device():
-    s = config("DEVICE", default="", section="train")
-    if s == "":
+def get_device(device: Optional[str] = None):
+    s = device or config("DEVICE", default="", section="train")
+    if not s:
         if torch.cuda.is_available():
-            DEVICE = torch.device("cuda:0")
-        else:
+            DEVICE = torch.device("cuda")
+        elif torch.mps.is_available():
+            DEVICE = torch.device("mps")
+        elif torch.xpu.is_available():
+            DEVICE = torch.device("xpu")
+        elif torch.cpu.is_available():
             DEVICE = torch.device("cpu")
+        else:
+            raise RuntimeError("No compute devices found")
     else:
         DEVICE = torch.device(s)
     return DEVICE
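
get_device() now takes an explicit override that wins over the DEVICE entry in the train config and probes several backends before giving up. A small usage sketch, assuming the df package is importable and its config has been initialized (e.g. via init_df):

    import torch
    from df.utils import get_device

    dev = get_device()               # no argument: use config DEVICE, else probe cuda/mps/xpu/cpu
    cpu = get_device(device="cpu")   # an explicit argument wins over config and auto-detection
    x = torch.zeros(1, device=cpu)   # tensors and modules can then be placed explicitly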

DeepFilterNet/pyproject.toml
Lines changed: 4 additions & 4 deletions

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "DeepFilterNet"
-version = "0.5.7-pre"
+version = "0.5.7"
 description = "Noise supression using deep filtering"
 authors = ["Hendrik Schröter"]
 repository = "https://github.com/Rikorose/DeepFilterNet"
@@ -23,9 +23,9 @@ include = [
 ]

 [tool.poetry.dependencies]
-deepfilterlib = { path = "../pyDF/" }
-deepfilterdataloader = { path = "../pyDF-data/", optional = true }
-python = ">=3.8,<4.0"
+deepfilterlib = { version = "0.5.7" }
+deepfilterdataloader = { version = "0.5.7", optional = true }
+python = ">=3.8,<3.14"
 numpy = ">=2,<3"
 loguru = ">=0.5"
 appdirs = "^1.4"

pyDF-data/pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "DeepFilterDataLoader"
-version = "0.5.7-pre"
+version = "0.5.7"
 classifiers = ["Programming Language :: Rust"]
 requires-python = ">=3.8"
 dependencies = ["numpy >= 1.22"]

pyDF/pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "DeepFilterLib"
-version = "0.5.7-pre"
+version = "0.5.7"
 classifiers = ["Programming Language :: Rust"]
 requires-python = ">=3.8"
 dependencies = ["numpy >= 1.22"]
