Skip to content

Commit 7ee1c46

Browse files
authored
Add Kaldi Pitch feature (#1243)
1 parent 9e58e75 commit 7ee1c46

24 files changed

+1025
-46
lines changed

.circleci/unittest/linux/scripts/run_style_checks.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ fi
3838

3939
printf "\x1b[34mRunning clang-format:\x1b[0m\n"
4040
"${this_dir}"/run_clang_format.py \
41-
-r torchaudio/csrc \
41+
-r torchaudio/csrc third_party/kaldi/src \
4242
--clang-format-executable "${clangformat_path}" \
4343
&& git diff --exit-code
4444
status=$?

.gitmodules

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,7 @@
22
path = third_party/transducer/submodule
33
url = https://github.com/HawkAaron/warp-transducer
44
ignore = dirty
5+
[submodule "kaldi"]
6+
path = third_party/kaldi/submodule
7+
url = https://github.com/kaldi-asr/kaldi
8+
ignore = dirty

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ endif()
4747

4848
# Options
4949
option(BUILD_SOX "Build libsox statically" OFF)
50+
option(BUILD_KALDI "Build kaldi statically" ON)
5051
option(BUILD_TRANSDUCER "Enable transducer" OFF)
5152
option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON)
5253
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)

build_tools/setup_helpers/extension.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def build_extension(self, ext):
6868
'-DCMAKE_VERBOSE_MAKEFILE=ON',
6969
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
7070
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
71+
"-DBUILD_KALDI:BOOL=ON",
7172
f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}",
7273
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
7374
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{"sample_rate": 8000}
2+
{"sample_rate": 8000, "frames_per_chunk": 200}
3+
{"sample_rate": 8000, "frames_per_chunk": 200, "simulate_first_pass_online": true}
4+
{"sample_rate": 16000}
5+
{"sample_rate": 44100}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import subprocess
2+
3+
import torch
4+
5+
6+
def convert_args(**kwargs):
    """Translate keyword arguments into Kaldi-style command line options.

    Every ``name=value`` pair becomes one ``--name=value`` string in which
    the underscores of the name are replaced by hyphens. ``sample_rate`` is
    emitted as ``--sample-frequency``, and values that compare equal to
    ``True`` or ``False`` are written in lower case.

    Returns:
        list of str: One option string per keyword argument, in the order
        the arguments were given.
    """
    options = []
    for name, val in kwargs.items():
        # Kaldi calls the sampling rate option "sample-frequency".
        option = 'sample_frequency' if name == 'sample_rate' else name
        option = '--' + option.replace('_', '-')
        text = str(val)
        if val in [True, False]:
            text = text.lower()
        options.append('%s=%s' % (option, text))
    return options
15+
16+
17+
def run_kaldi(command, input_type, input_value):
    """Run provided Kaldi command, pass a tensor and get the resulting tensor

    Args:
        command: list of str
            The command line to execute. It is expected to read its input
            from stdin and to write an ark stream to stdout.
        input_type: str
            'ark' or 'scp'
        input_value:
            Tensor for 'ark'
            string for 'scp' (path to an audio file)

    Returns:
        Tensor: The matrix read back from the command's stdout under the
        same key that was used for the input.

    Raises:
        NotImplementedError: If ``input_type`` is neither 'ark' nor 'scp'.
    """
    # Reject unsupported input before spawning anything, so that no child
    # process is left behind with open pipes.
    if input_type not in ('ark', 'scp'):
        raise NotImplementedError('Unexpected type')

    import kaldi_io

    key = 'foo'
    # The context manager closes the pipes and waits for the child process
    # on exit, including when reading or writing raises.
    with subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as process:
        if input_type == 'ark':
            kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
        else:
            process.stdin.write(f'{key} {input_value}'.encode('utf8'))
        # Close stdin before reading so the command sees end-of-input.
        process.stdin.close()
        result = dict(kaldi_io.read_mat_ark(process.stdout))[key]
    return torch.from_numpy(result.copy())  # copy suppresses some torch warning

test/torchaudio_unittest/functional/batch_consistency_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,9 @@ def test_vad(self):
184184
waveform, sample_rate = torchaudio.load(filepath)
185185
self.assert_batch_consistencies(
186186
F.vad, waveform, sample_rate=sample_rate)
187+
188+
@common_utils.skipIfNoExtension
def test_compute_kaldi_pitch(self):
    """Run the batch consistency check on ``F.compute_kaldi_pitch`` with white noise at 44.1 kHz."""
    rate = 44100
    noise = common_utils.get_whitenoise(sample_rate=rate)
    self.assert_batch_consistencies(
        F.compute_kaldi_pitch,
        noise,
        sample_rate=rate,
    )
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import torch
2+
3+
from torchaudio_unittest.common_utils import PytorchTestCase
4+
from .kaldi_compatibility_test_impl import KaldiCPUOnly
5+
6+
7+
class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
    """Concrete ``KaldiCPUOnly`` test case bound to float32 tensors on the CPU."""

    device = torch.device('cpu')
    dtype = torch.float32
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from parameterized import parameterized
2+
import torchaudio.functional as F
3+
4+
from torchaudio_unittest.common_utils import (
5+
get_sinusoid,
6+
load_params,
7+
save_wav,
8+
skipIfNoExec,
9+
TempDirMixin,
10+
TestBaseMixin,
11+
)
12+
from torchaudio_unittest.common_utils.kaldi_utils import (
13+
convert_args,
14+
run_kaldi,
15+
)
16+
17+
18+
class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
    """Tests comparing torchaudio functions against Kaldi command line tools."""

    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        """Assert ``output`` equals ``expected`` once ``expected`` is cast to the test dtype and device."""
        reference = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, reference, rtol=rtol, atol=atol)

    @parameterized.expand(load_params('kaldi_test_pitch_args.json'))
    @skipIfNoExec('compute-kaldi-pitch-feats')
    def test_pitch_feats(self, kwargs):
        """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
        rate = kwargs['sample_rate']

        # torchaudio side: compute pitch features from a float32 sinusoid.
        float_wave = get_sinusoid(dtype='float32', sample_rate=rate)
        result = F.compute_kaldi_pitch(float_wave[0], **kwargs)

        # Kaldi side: the command line tool reads an int16 sinusoid from a wav file.
        int_wave = get_sinusoid(dtype='int16', sample_rate=rate)
        wav_path = self.get_temp_path('test.wav')
        save_wav(wav_path, int_wave, rate)

        command = ['compute-kaldi-pitch-feats']
        command += convert_args(**kwargs)
        command += ['scp:-', 'ark:-']
        kaldi_result = run_kaldi(command, 'scp', wav_path)

        self.assert_equal(result, expected=kaldi_result)

test/torchaudio_unittest/functional/torchscript_consistency_impl.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,3 +547,15 @@ def func(tensor):
547547

548548
tensor = common_utils.get_whitenoise(sample_rate=44100)
549549
self._assert_consistency(func, tensor)
550+
551+
@common_utils.skipIfNoExtension
def test_compute_kaldi_pitch(self):
    """Run the TorchScript consistency check on ``F.compute_kaldi_pitch`` (float32 on CPU only)."""
    supported = self.dtype == torch.float32 and self.device == torch.device('cpu')
    if not supported:
        raise unittest.SkipTest("Only float32, cpu is supported.")

    def func(tensor):
        sample_rate: float = 44100.
        return F.compute_kaldi_pitch(tensor, sample_rate)

    noise = common_utils.get_whitenoise(sample_rate=44100)
    self._assert_consistency(func, noise)

0 commit comments

Comments
 (0)