Skip to content

Commit 8a86c46

Browse files
authored
Raise error when scripting invalid MelScale (#1505)
1 parent 9d621fd commit 8a86c46

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

test/torchaudio_unittest/transforms/torchscript_consistency_impl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def test_AmplitudeToDB(self):
5959
spec = torch.rand((6, 201))
6060
self._assert_consistency(T.AmplitudeToDB(), spec)
6161

62+
def test_MelScale_invalid(self):
63+
with self.assertRaises(ValueError):
64+
torch.jit.script(T.MelScale())
65+
6266
def test_MelScale(self):
6367
spec_f = torch.rand((1, 201, 6))
6468
self._assert_consistency(T.MelScale(n_stft=201), spec_f)

torchaudio/transforms.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,20 @@ def __init__(self,
284284
self.mel_scale)
285285
self.register_buffer('fb', fb)
286286

287+
def __prepare_scriptable__(self):
288+
r"""If `self.fb` is empty, the `forward` method will try to resize the parameter,
289+
which does not work once the transform is scripted. However, this error does not happen
290+
until the transform is executed. This is inconvenient especially if the resulting
291+
TorchScript object is executed in other environments. Therefore, we check the
292+
validity of `self.fb` here and fail if the resulting TS does not work.
293+
294+
Returns:
295+
MelScale: self
296+
"""
297+
if self.fb.numel() == 0:
298+
raise ValueError("n_stft must be provided at construction")
299+
return self
300+
287301
def forward(self, specgram: Tensor) -> Tensor:
288302
r"""
289303
Args:

0 commit comments

Comments
 (0)