diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index 2a9c41f073..e016bb3b63 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -72,7 +72,7 @@ fi ( set -x conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20' - pip install kaldi-io SoundFile coverage pytest pytest-cov 'scipy==1.7.3' transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag + pip install kaldi-io SoundFile coverage pytest pytest-cov 'scipy==1.7.3' transformers expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag pyroomacoustics ) # Install fairseq git clone https://github.com/pytorch/fairseq diff --git a/.circleci/unittest/windows/scripts/install.sh b/.circleci/unittest/windows/scripts/install.sh index c1a308ec38..3d58795e69 100644 --- a/.circleci/unittest/windows/scripts/install.sh +++ b/.circleci/unittest/windows/scripts/install.sh @@ -90,7 +90,8 @@ esac unidecode \ 'protobuf<4.21.0' \ demucs \ - tinytag + tinytag \ + pyroomacoustics ) # Install fairseq git clone https://github.com/pytorch/fairseq diff --git a/CMakeLists.txt b/CMakeLists.txt index 696a736a78..74af97f897 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,6 +58,7 @@ endif() # Options option(BUILD_SOX "Build libsox statically" ON) option(BUILD_KALDI "Build kaldi statically" ON) +option(BUILD_RAY_TRACING "Enable ray tracing simulation" ON) # TODO: REMOVE THIS option(BUILD_RNNT "Enable RNN transducer" ON) option(BUILD_CTC_DECODER "Build Flashlight CTC decoder" ON) option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF) diff --git a/docs/source/prototype.functional.rst b/docs/source/prototype.functional.rst index 81b2b6a197..65f27c39d9 100644 --- a/docs/source/prototype.functional.rst +++ b/docs/source/prototype.functional.rst @@ -50,3 +50,8 @@ DSP extend_pitch oscillator_bank sinc_impulse_response + +ray_tracing +~~~~~~~~~~~ + +.. 
autofunction:: ray_tracing diff --git a/docs/source/refs.bib b/docs/source/refs.bib index c0440ed33a..106246441a 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -464,6 +464,14 @@ @inproceedings{GigaSpeech2021 year=2021, author={Guoguo Chen and Shuzhou Chai and Guanbo Wang and Jiayu Du and Wei-Qiang Zhang and Chao Weng and Dan Su and Daniel Povey and Jan Trmal and Junbo Zhang and Mingjie Jin and Sanjeev Khudanpur and Shinji Watanabe and Shuaijiang Zhao and Wei Zou and Xiangang Li and Xuchen Yao and Yongqing Wang and Yujun Wang and Zhao You and Zhiyong Yan} } +@inproceedings{scheibler2018pyroomacoustics, + title={Pyroomacoustics: A python package for audio room simulation and array processing algorithms}, + author={Scheibler, Robin and Bezzam, Eric and Dokmani{\'c}, Ivan}, + booktitle={2018 IEEE international conference on acoustics, speech and signal processing (ICASSP)}, + pages={351--355}, + year={2018}, + organization={IEEE} +} @inproceedings{ko15_interspeech, author={Tom Ko and Vijayaditya Peddinti and Daniel Povey and Sanjeev Khudanpur}, title={{Audio augmentation for speech recognition}}, diff --git a/test/torchaudio_unittest/prototype/functional/functional_cpu_test.py b/test/torchaudio_unittest/prototype/functional/functional_cpu_test.py index 777430a0c3..8c4a263975 100644 --- a/test/torchaudio_unittest/prototype/functional/functional_cpu_test.py +++ b/test/torchaudio_unittest/prototype/functional/functional_cpu_test.py @@ -1,7 +1,7 @@ import torch from torchaudio_unittest.common_utils import PytorchTestCase -from .functional_test_impl import Functional64OnlyTestImpl, FunctionalTestImpl +from .functional_test_impl import Functional64OnlyTestImpl, FunctionalCPUOnlyTestImpl, FunctionalTestImpl class FunctionalFloat32CPUTest(FunctionalTestImpl, PytorchTestCase): @@ -17,3 +17,13 @@ class FunctionalFloat64CPUTest(FunctionalTestImpl, PytorchTestCase): class FunctionalFloat64OnlyCPUTest(Functional64OnlyTestImpl, PytorchTestCase): dtype = torch.float64 device = torch.device("cpu") + + +class FunctionalFloat32CPUOnlyTest(FunctionalCPUOnlyTestImpl, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") + + +class FunctionalFloat64CPUOnlyTest(FunctionalCPUOnlyTestImpl, PytorchTestCase): + dtype = torch.float64 + device = torch.device("cpu") diff --git a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py index 8d6c7fd332..e01ca92d89 100644 --- a/test/torchaudio_unittest/prototype/functional/functional_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/functional_test_impl.py @@ -5,11 +5,15 @@ import torchaudio.prototype.functional as F from parameterized import param, parameterized from scipy import signal +from torchaudio._internal import module_utils as _mod_utils from torchaudio.functional import lfilter -from torchaudio_unittest.common_utils import nested_params, TestBaseMixin +from torchaudio_unittest.common_utils import nested_params, skipIfNoModule, TestBaseMixin from .dsp_utils import oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np +if _mod_utils.is_module_available("pyroomacoustics"): + import pyroomacoustics as pra + def _prod(l): r = 1 @@ -518,3 +522,306 @@ def _debug_plot(): except AssertionError: _debug_plot() raise + + +class FunctionalCPUOnlyTestImpl(TestBaseMixin): + @parameterized.expand( + [ + (0.1, 0.2, (2, 1, 2500)), # both float + # Per-wall + (torch.rand(4), 0.2, (2, 1, 2500)), + (0.1, torch.rand(4), (2, 1, 2500)), + 
(torch.rand(4), torch.rand(4), (2, 1, 2500)), + # Per-band and per-wall + (torch.rand(6, 4), 0.2, (2, 6, 2500)), + (0.1, torch.rand(6, 4), (2, 6, 2500)), + (torch.rand(6, 4), torch.rand(6, 4), (2, 6, 2500)), + ] + ) + def test_ray_tracing_output_shape(self, absorption, scattering, expected_shape): + room_dim = torch.tensor([20, 25], dtype=self.dtype) + mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype) + source = torch.tensor([7, 6], dtype=self.dtype) + num_rays = 100 + + hist = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + + assert hist.shape == expected_shape + + def test_ray_tracing_input_errors(self): + with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"): + F.ray_tracing( + room=torch.tensor([[4, 5]]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[3, 4]]), num_rays=10 + ) + with self.assertRaisesRegex(ValueError, "room must be a 1D tensor"): + F.ray_tracing( + room=torch.tensor([4, 5, 4, 5]), + source=torch.tensor([0, 0]), + mic_array=torch.tensor([[3, 4]]), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, r"mic_array must be 1D tensor of shape \(D,\), or 2D tensor"): + F.ray_tracing( + room=torch.tensor([4, 5]), source=torch.tensor([0, 0]), mic_array=torch.tensor([[[3, 4]]]), num_rays=10 + ) + with self.assertRaisesRegex(ValueError, "room must be of float32 or float64 dtype"): + F.ray_tracing( + room=torch.tensor([4, 5]).to(torch.int), + source=torch.tensor([0, 0]), + mic_array=torch.tensor([3, 4]), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "dtype of room, source and mic_array must be the same"): + F.ray_tracing( + room=torch.tensor([4, 5]).to(torch.float64), + source=torch.tensor([0, 0]).to(torch.float32), + mic_array=torch.tensor([3, 4]), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): + F.ray_tracing( + room=torch.tensor([4, 5, 10], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "Room dimension D must match with source and mic_array"): + F.ray_tracing( + room=torch.tensor([4, 5, 10], dtype=torch.float), + source=torch.tensor([0, 0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + ) + with self.assertRaisesRegex(ValueError, "time_thres=10 must be at least greater than hist_bin_size=11"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + time_thres=10, + hist_bin_size=11, + ) + with self.assertRaisesRegex(ValueError, "The shape of absorption must be"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(5, dtype=torch.float), + ) + with self.assertRaisesRegex(ValueError, "The shape of scattering must be"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], 
dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + scattering=torch.rand(5, 5, dtype=torch.float), + ) + with self.assertRaisesRegex(ValueError, "The shape of absorption must be"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(5, 5, dtype=torch.float), + ) + with self.assertRaisesRegex(ValueError, "The shape of scattering must be"): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + scattering=torch.rand(5, dtype=torch.float), + ) + with self.assertRaisesRegex( + ValueError, "absorption and scattering must have the same number of bands and walls" + ): + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(6, 4, dtype=torch.float), + scattering=torch.rand(5, 4, dtype=torch.float), + ) + + # Make sure passing different shapes for absorption or scattering doesn't raise an error + # float and tensor + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=0.1, + scattering=torch.rand(5, 4, dtype=torch.float), + ) + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(5, 4, dtype=torch.float), + scattering=0.1, + ) + # per-wall only and per-band + per-wall + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(4, dtype=torch.float), + scattering=torch.rand(6, 4, dtype=torch.float), + ) + F.ray_tracing( + room=torch.tensor([4, 5], dtype=torch.float), + source=torch.tensor([0, 0], dtype=torch.float), + mic_array=torch.tensor([3, 4], dtype=torch.float), + num_rays=10, + absorption=torch.rand(6, 4, dtype=torch.float), + scattering=torch.rand(4, dtype=torch.float), + ) + + def test_ray_tracing_per_band_per_wall_absorption(self): + """Check that when the value of absorption and scattering are the same + across walls and frequency bands, the output histograms are: + - all equal across frequency bands + - equal to simply passing a float value instead of a (num_bands, D) or + (D,) tensor. 
+ """ + + room_dim = torch.tensor([20, 25], dtype=self.dtype) + mic_array = torch.tensor([[2, 2], [8, 8]], dtype=self.dtype) + source = torch.tensor([7, 6], dtype=self.dtype) + num_rays = 1_000 + ABS, SCAT = 0.1, 0.2 + + absorption = torch.full(fill_value=ABS, size=(6, 4), dtype=self.dtype) + scattering = torch.full(fill_value=SCAT, size=(6, 4), dtype=self.dtype) + hist_per_band_per_wall = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + absorption = torch.full(fill_value=ABS, size=(4,), dtype=self.dtype) + scattering = torch.full(fill_value=SCAT, size=(4,), dtype=self.dtype) + hist_per_wall = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + + absorption = ABS + scattering = SCAT + hist_single = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + ) + assert hist_per_band_per_wall.shape == (2, 6, 2500) + assert hist_per_wall.shape == (2, 1, 2500) + assert hist_single.shape == (2, 1, 2500) + torch.testing.assert_close(hist_single, hist_per_wall) + + hist_single = hist_single.expand(2, 6, 2500) + torch.testing.assert_close(hist_single, hist_per_band_per_wall) + + @skipIfNoModule("pyroomacoustics") + @parameterized.expand( + [ + ([20, 25], [2, 2], [[8, 8], [7, 6]], 10_000), # 2D with 2 mics + ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 1_000), # 3D with 1 mic + ] + ) + def test_ray_tracing_same_results_as_pyroomacoustics(self, room_dim, source, mic_array, num_rays): + + walls = ["west", "east", "south", "north"] + if len(room_dim) == 3: + walls += ["floor", "ceiling"] + num_walls = len(walls) + num_bands = 6 # Note: in ray tracing, we don't need to restrict the number of bands to 7 + + absorption = torch.rand(num_bands, num_walls, dtype=self.dtype) + scattering = torch.rand(num_bands, num_walls, dtype=self.dtype) + energy_thres = 1e-7 + time_thres = 10.0 + hist_bin_size = 0.004 + mic_radius = 0.5 + sound_speed = 343.0 + + room_dim = torch.tensor(room_dim, dtype=self.dtype) + source = torch.tensor(source, dtype=self.dtype) + mic_array = torch.tensor(mic_array, dtype=self.dtype) + + room = pra.ShoeBox( + room_dim.tolist(), + ray_tracing=True, + materials={ + walls[i]: pra.Material( + energy_absorption={ + "coeffs": absorption[:, i].reshape(-1).detach().numpy(), + "center_freqs": 125 * 2 ** np.arange(num_bands), + }, + scattering={ + "coeffs": scattering[:, i].reshape(-1).detach().numpy(), + "center_freqs": 125 * 2 ** np.arange(num_bands), + }, + ) + for i in range(num_walls) + }, + air_absorption=False, + max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing) + ) + room.add_microphone_array(mic_array.T.tolist()) + room.add_source(source.tolist()) + room.set_ray_tracing( + n_rays=num_rays, + energy_thres=energy_thres, + time_thres=time_thres, + hist_bin_size=hist_bin_size, + receiver_radius=mic_radius, + ) + room.set_sound_speed(sound_speed) + + room.compute_rir() + hist_pra = torch.tensor(np.array(room.rt_histograms))[:, 0, 0] + + hist = F.ray_tracing( + room=room_dim, + source=source, + mic_array=mic_array, + num_rays=num_rays, + absorption=absorption, + scattering=scattering, + sound_speed=sound_speed, + mic_radius=mic_radius, + energy_thres=energy_thres, + time_thres=time_thres, + hist_bin_size=hist_bin_size, + ) + + assert hist.ndim == 3 + assert hist.shape == hist_pra.shape + 
self.assertEqual(hist.to(torch.float32), hist_pra) diff --git a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_cpu_test.py index 917a45bf1a..3b81309856 100644 --- a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_cpu_test.py +++ b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_cpu_test.py @@ -1,7 +1,7 @@ import torch from torchaudio_unittest.common_utils import PytorchTestCase -from .torchscript_consistency_test_impl import TorchScriptConsistencyTestImpl +from .torchscript_consistency_test_impl import TorchScriptConsistencyCPUOnlyTestImpl, TorchScriptConsistencyTestImpl class TorchScriptConsistencyCPUFloat32Test(TorchScriptConsistencyTestImpl, PytorchTestCase): @@ -12,3 +12,13 @@ class TorchScriptConsistencyCPUFloat32Test(TorchScriptConsistencyTestImpl, Pytor class TorchScriptConsistencyCPUFloat64Test(TorchScriptConsistencyTestImpl, PytorchTestCase): dtype = torch.float64 device = torch.device("cpu") + + +class TorchScriptConsistencyCPUOnlyFloat32Test(TorchScriptConsistencyCPUOnlyTestImpl, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") + + +class TorchScriptConsistencyCPUOnlyFloat64Test(TorchScriptConsistencyCPUOnlyTestImpl, PytorchTestCase): + dtype = torch.float64 + device = torch.device("cpu") diff --git a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py index ced35809b8..1b47e07e1b 100644 --- a/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py +++ b/test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py @@ -2,6 +2,7 @@ import torch import torchaudio.prototype.functional as F +from parameterized import parameterized from torchaudio_unittest.common_utils import nested_params, TestBaseMixin, torch_script @@ -98,3 +99,45 @@ def test_deemphasis(self): waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype) coeff = 0.9 self._assert_consistency(F.deemphasis, (waveform, coeff)) + + +class TorchScriptConsistencyCPUOnlyTestImpl(TorchScriptConsistencyTestImpl, TestBaseMixin): + @parameterized.expand( + [ + ([20, 25], [2, 2], [[8, 8], [7, 6]], 1_000), # 2D with 2 mics + ([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic + ] + ) + def test_ray_tracing(self, room_dim, source, mic_array, num_rays): + num_walls = 4 if len(room_dim) == 2 else 6 + num_bands = 3 + + absorption = torch.rand(num_bands, num_walls, dtype=torch.float32) + scattering = torch.rand(num_bands, num_walls, dtype=torch.float32) + + energy_thres = 1e-7 + time_thres = 10.0 + hist_bin_size = 0.004 + mic_radius = 0.5 + sound_speed = 343.0 + + room_dim = torch.tensor(room_dim, dtype=self.dtype) + source = torch.tensor(source, dtype=self.dtype) + mic_array = torch.tensor(mic_array, dtype=self.dtype) + + self._assert_consistency( + F.ray_tracing, + ( + room_dim, + source, + mic_array, + num_rays, + absorption, + scattering, + mic_radius, + sound_speed, + energy_thres, + time_thres, + hist_bin_size, + ), + ) diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index ef8e9f0a7a..b8f193bb82 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -53,6 +53,15 @@ if(BUILD_RNNT) endif() endif() +# TODO: Probably remove or edit this, cond shouldn't be BUILD_RAY_TRACING +if(BUILD_RAY_TRACING) + list( + APPEND + 
LIBTORCHAUDIO_SOURCES + ray_tracing.cpp + ) +endif() + if(USE_CUDA) list( APPEND diff --git a/torchaudio/csrc/ray_tracing.cpp b/torchaudio/csrc/ray_tracing.cpp new file mode 100644 index 0000000000..5291d9cc3a --- /dev/null +++ b/torchaudio/csrc/ray_tracing.cpp @@ -0,0 +1,576 @@ +/* +Copyright (c) 2014-2017 EPFL-LCAV + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +/** + * Ray tracing implementation. This is heavily based on PyRoomAcoustics: + * https://github.com/LCAV/pyroomacoustics + */ +#include <torch/script.h> +#include <cmath> +#include <vector> + +namespace torchaudio { +namespace rir { +namespace { + +#define IS_HYBRID_SIM (false) // TODO: remove this once ISM method is supported +#define ISM_ORDER (10) // TODO: remove this once ISM method is supported +#define EPS ((scalar_t)(1e-5)) +#define VAL(t) (*((t).template data_ptr<scalar_t>())) + +/** + * Wall helper class. A wall records its own absorption, reflection and + * scattering coefficients, and exposes a few methods for geometrical operations + * (e.g. reflection of a ray). + */ +template <typename scalar_t> +class Wall { + public: + Wall( + const torch::Tensor _absorption, + const torch::Tensor _scattering, + const torch::Tensor _normal, + const torch::Tensor _origin) + : absorption(std::move(_absorption)), + reflection((scalar_t)1. - _absorption), + scattering(std::move(_scattering)), + normal(std::move(_normal)), + origin(std::move(_origin)) {} + + torch::Tensor get_absorption() { + return absorption; + } + torch::Tensor get_reflection() { + return reflection; + } + torch::Tensor get_scattering() { + return scattering; + } + + /** + * Returns the side (-1, 1 or 0) on which a point lies w.r.t. the wall. + */ + int side(const torch::Tensor& pos) { + auto dot = VAL((pos - origin).dot(normal)); + + if (dot > EPS) { + return 1; + } else if (dot < -EPS) { + return -1; + } else { + return 0; + } + } + + /** + * Reflects a ray (dir) on the wall. Preserves the norm of the vector. + */ + torch::Tensor reflect(const torch::Tensor& dir) { + return dir - normal * 2 * dir.dot(normal); + } + + /** + * Returns the cosine angle of a ray (dir) with the normal of the wall. + */ + scalar_t cosine(const torch::Tensor& dir) { + return VAL(dir.dot(normal) / dir.norm()); + } + + private: + torch::Tensor absorption; + torch::Tensor reflection; // == 1 - absorption + torch::Tensor scattering; + torch::Tensor normal; // The normal to the wall: 2D or 3D vector + torch::Tensor + origin; // The origin of the wall: corresponds to an arbitrary corner. +}; + +/** + * RayTracer helper class for ray tracing.
For attribute descriptions, please see + * the declarations below as well as the Python wrapper. + */ +template <typename scalar_t> +class RayTracer { + public: + RayTracer( + const torch::Tensor& _room, + const torch::Tensor& _source, + const torch::Tensor& _mic_array, + int _num_rays, + const torch::Tensor& _absorption, + const torch::Tensor& _scattering, + scalar_t _mic_radius, + scalar_t _sound_speed, + scalar_t _energy_thres, + scalar_t _time_thres, + scalar_t _hist_bin_size) + : room(_room), + source(_source), + mic_array(_mic_array), + num_rays(_num_rays), + energy_0(2. / num_rays), + mic_radius(_mic_radius), + mic_radius_sq(mic_radius * mic_radius), + sound_speed(_sound_speed), + energy_thres(_energy_thres), + time_thres(_time_thres), + hist_bin_size(_hist_bin_size), + max_dist(VAL(room.norm()) + (scalar_t)1.), + D(room.size(0)), + do_scattering(VAL(_scattering.max()) > (scalar_t)0.), + walls(make_walls(_absorption, _scattering)) {} + + /** + * The main (and only) public entry point of this class. The histograms Tensor + * reference is passed along and modified in the subsequent private method + * calls. This method spawns num_rays rays in all directions from the source + * and calls simul_ray() on each of them. + */ + void compute_histograms(torch::Tensor& histograms) { + // TODO: the for loop can be parallelized over num_rays by creating + // `num_threads` histograms and then sum-reducing them into a single + // histogram. + + if (D == 3) { + scalar_t offset = 2. / num_rays; + scalar_t increment = M_PI * (3. - std::sqrt(5.)); // phi increment + + for (auto i = 0; i < num_rays; ++i) { + auto z = (i * offset - 1) + offset / 2.; + auto rho = std::sqrt(1. - z * z); + + scalar_t phi = i * increment; + + auto x = cos(phi) * rho; + auto y = sin(phi) * rho; + + auto azimuth = atan2(y, x); + auto colatitude = atan2(std::sqrt(x * x + y * y), z); + + simul_ray(histograms, azimuth, colatitude); + } + } else if (D == 2) { + scalar_t offset = 2.
* M_PI / num_rays; + for (int i = 0; i < num_rays; ++i) { + simul_ray(histograms, i * offset, 0.); + } + } + } + + private: + const torch::Tensor& room; + const torch::Tensor& source; + const torch::Tensor& mic_array; + int num_rays; + scalar_t energy_0; // initial energy of each ray + scalar_t mic_radius; + double mic_radius_sq; + scalar_t sound_speed; + scalar_t energy_thres; + scalar_t time_thres; + scalar_t hist_bin_size; + scalar_t max_dist; // Max distance needed to hit a wall = diagonal of room + 1 + int D; // Dimension of the room + const bool do_scattering; // Whether scattering is needed (scattering != 0) + std::vector<Wall<scalar_t>> walls; // The walls of the room + + /** + * From a ray vector defined by its start and end, returns the next wall hit + * as a 3-tuple: + * - the hit point on the wall: 2D or 3D tensor + * - the index of the wall (as in the .walls vector attribute) + * - the distance from the start to the wall + */ + std::tuple<torch::Tensor, int, scalar_t> next_wall_hit( + const torch::Tensor& start, + const torch::Tensor& end) { + const static std::vector<std::vector<int>> shoebox_orders = { + {0, 1, 2}, {1, 0, 2}, {2, 0, 1}}; + + torch::Tensor hit_point = torch::zeros_like(room); + int next_wall_index = -1; + auto hit_dist = max_dist; + + torch::Tensor dir = end - start; + + auto dir_a = dir.accessor<scalar_t, 1>(); + auto hit_point_a = hit_point.accessor<scalar_t, 1>(); + auto start_a = start.accessor<scalar_t, 1>(); + auto room_a = room.accessor<scalar_t, 1>(); + + for (auto& d : shoebox_orders) { + if (d[0] >= D) { // Happens for 2D rooms + continue; + } + auto abs_dir0 = std::abs(dir_a[d[0]]); + if (abs_dir0 < EPS) { + continue; + } + + // distance to plane + auto distance = 0.; + + // this will tell us if the front or back plane is hit + int ind_inc = 0; + + if (dir_a[d[0]] < 0.) { + hit_point[d[0]] = 0.; + distance = start_a[d[0]]; + ind_inc = 0; + } else { + hit_point_a[d[0]] = room_a[d[0]]; + distance = room_a[d[0]] - start_a[d[0]]; + ind_inc = 1; + } + + if (distance < EPS) { + continue; + } + + auto ratio = distance / abs_dir0; + + // Now compute the intersection point and verify if intersection happens + auto intersection_found = true; + for (auto i = 1; i < D; ++i) { + hit_point_a[d[i]] = start_a[d[i]] + ratio * dir_a[d[i]]; + // when there is no intersection, we jump to the next plane + if ((hit_point_a[d[i]] <= -EPS) || + (room_a[d[i]] + EPS <= hit_point_a[d[i]])) { + intersection_found = false; + break; // check next plane + } + } + + if (intersection_found) { + next_wall_index = 2 * d[0] + ind_inc; + hit_dist = VAL((hit_point - start).norm()); + break; + } + } + return std::make_tuple(hit_point, next_wall_index, hit_dist); + } + + /** + * Adds an energy level to the output histogram for a given microphone and a given + * time-bin (computed from travel_dist_at_mic). + */ + void log_hist( + torch::Tensor& histograms, + int mic_idx, + const torch::Tensor& energy, + scalar_t travel_dist_at_mic) { + auto time_at_mic = travel_dist_at_mic / sound_speed; + auto bin = (int)floor(time_at_mic / hist_bin_size); + histograms[mic_idx][bin] += energy; + } + + /** + * Traces a single ray. phi (horizontal) and theta (vertical) are the angles + * of the ray from the source. Theta is 0 for 2D rooms. When a ray intersects + * a wall, it is reflected and part of its energy is absorbed. It is also + * scattered (sent directly to the microphone(s)) according to the scattering + * coefficient. When a ray is close to the microphone, its current energy is + * recorded in the output histogram for that given time slot.
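+ * The energy recorded in a bin is transmitted / (r^2 * p_hit), where r is the total travel distance at the microphone and p_hit = 1 - sqrt(1 - mic_radius^2 / max(mic_radius^2, r^2)) accounts for the probability that a ray travelling that distance intersects the microphone sphere, following the pyroomacoustics formulation.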
+ */ + void simul_ray(torch::Tensor& histograms, scalar_t phi, scalar_t theta) { + torch::Tensor start = source.clone(); + + // the direction of the ray (unit vector) + torch::Tensor dir; + if (D == 2) { + dir = torch::tensor({cos(phi), sin(phi)}, room.scalar_type()); + } else if (D == 3) { + dir = torch::tensor( + {sin(theta) * cos(phi), sin(theta) * sin(phi), cos(theta)}, + room.scalar_type()); + } + + int next_wall_index = -1; + + auto num_bands = histograms.size(2); + auto transmitted = torch::ones({num_bands}) * energy_0; + auto energy = torch::ones({num_bands}); + auto travel_dist = 0.; + + // To count the number of times the ray bounces on the walls + // For hybrid generation we add a ray to output only if specular_counter + // is higher than the ism order. + int specular_counter = 0; + + // Convert the energy threshold to transmission threshold + auto e_thres = energy_0 * energy_thres; + auto distance_thres = time_thres * sound_speed; + + torch::Tensor hit_point = torch::zeros(D, room.scalar_type()); + + while (true) { + // Find the next hit point + auto hit_distance = 0.; + std::tie(hit_point, next_wall_index, hit_distance) = + next_wall_hit(start, start + dir * max_dist); + + // If no wall is hit (rounding errors), stop the ray + if (next_wall_index == -1) { + break; + } + + auto wall = walls[next_wall_index]; + + // Check if the specular ray hits any of the microphone + if (!(IS_HYBRID_SIM && specular_counter < ISM_ORDER)) { + // Compute the distance between the line defined by (start, hit_point) + // and the center of the microphone (mic_pos) + + for (auto mic_idx = 0; mic_idx < mic_array.size(0); mic_idx++) { + torch::Tensor to_mic = mic_array[mic_idx] - start; + scalar_t impact_distance = VAL(to_mic.dot(dir)); + + bool impacts = (-EPS < impact_distance) && + (impact_distance < hit_distance + EPS); + + // If yes, we compute the ray's transmitted amplitude at the mic and + // we continue the ray + if (impacts && + (VAL((to_mic - dir * impact_distance).norm()) < + mic_radius + EPS)) { + // The length of this last hop + auto distance = std::abs(impact_distance); + + auto travel_dist_at_mic = travel_dist + distance; + double r_sq = travel_dist_at_mic * travel_dist_at_mic; + auto p_hit = + ((scalar_t)1. - + std::sqrt( + (scalar_t)1. - + mic_radius_sq / std::max(mic_radius_sq, r_sq))); + energy = transmitted / (r_sq * p_hit); + + log_hist(histograms, mic_idx, energy, travel_dist_at_mic); + } + } + } + + travel_dist += hit_distance; + transmitted *= wall.get_reflection(); + + // Let's shoot the scattered ray induced by the rebound on the wall + if (do_scattering) { + scat_ray(histograms, wall, transmitted, start, hit_point, travel_dist); + transmitted = transmitted * (1. - wall.get_scattering()); + } + + // Check if we reach the thresholds for this ray + if (travel_dist > distance_thres || VAL(transmitted.max()) < e_thres) { + break; + } + + // set up for next iteration + specular_counter = specular_counter + 1; + dir = wall.reflect(dir); // reflect w.r.t normal while conserving length + start = hit_point; + } + } + + /** + * Scatters a ray towards the microphone(s), i.e. records its scattered energy + * in the histogram. Called when a ray hits a wall. 
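+ * The scattered contribution uses a Lambert cosine model: scat_trans = scattering * transmitted * p_hit_equal * p_lambert, where p_lambert is 2 * |cos| of the angle between the wall normal and the wall-to-microphone vector, and p_hit_equal = 1 - sqrt(1 - mic_radius^2 / hop_dist^2) is the probability of hitting the microphone sphere from the hit point.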
+ */ + void scat_ray( + torch::Tensor& histograms, + Wall<scalar_t>& wall, + const torch::Tensor& transmitted, + const torch::Tensor& prev_hit_point, + const torch::Tensor& hit_point, + scalar_t travel_dist) { + auto distance_thres = time_thres * sound_speed; + + for (auto mic_idx = 0; mic_idx < mic_array.size(0); mic_idx++) { + auto mic_pos = mic_array[mic_idx]; + if (wall.side(mic_pos) != wall.side(prev_hit_point)) { + continue; + } + + // As the ray is shot towards the microphone center, + // the hop dist can be easily computed + torch::Tensor hit_point_to_mic = mic_pos - hit_point; + auto hop_dist = VAL(hit_point_to_mic.norm()); + auto travel_dist_at_mic = travel_dist + hop_dist; + + // compute the scattered energy reaching the microphone + auto h_sq = hop_dist * hop_dist; + auto p_hit_equal = + (scalar_t)1. - std::sqrt((scalar_t)1. - mic_radius_sq / h_sq); + // cosine angle should be positive, but could be negative if the normal is + // facing out of the room, so we take the abs + auto p_lambert = (scalar_t)2. * std::abs(wall.cosine(hit_point_to_mic)); + auto scat_trans = + wall.get_scattering() * transmitted * p_hit_equal * p_lambert; + + if (travel_dist_at_mic < distance_thres && + VAL(scat_trans.max()) > energy_thres) { + double r_sq = double(travel_dist_at_mic) * travel_dist_at_mic; + auto p_hit = + ((scalar_t)1. - + std::sqrt( + (scalar_t)1. - mic_radius_sq / std::max(mic_radius_sq, r_sq))); + auto energy = scat_trans / (r_sq * p_hit); + log_hist(histograms, mic_idx, energy, travel_dist_at_mic); + } + } + } + + /** + * Creates the walls based on the input to the constructor. + * Since the room is always a shoebox we can hard-code values. + * Normals are vectors facing *outwards* from the room, and origins are arbitrary + * corners of each wall. + */ + std::vector<Wall<scalar_t>> make_walls( + const torch::Tensor& _absorption, + const torch::Tensor& _scattering) { + auto room_a = room.accessor<scalar_t, 1>(); + scalar_t zero = 0; + scalar_t W = room_a[0]; + scalar_t L = room_a[1]; + + std::vector<Wall<scalar_t>> walls; + + torch::Tensor normals; + torch::Tensor origins; + + if (D == 2) { + normals = torch::tensor( + { + {-1, 0}, // West + {1, 0}, // East + {0, -1}, // South + {0, 1}, // North + }, + room.scalar_type()); + + origins = torch::tensor( + { + {zero, L}, // West + {W, zero}, // East + {zero, zero}, // South + {W, L}, // North + }, + room.scalar_type()); + } else { + scalar_t H = room_a[2]; + normals = torch::tensor( + { + {-1, 0, 0}, // West + {1, 0, 0}, // East + {0, -1, 0}, // South + {0, 1, 0}, // North + {0, 0, -1}, // Floor + {0, 0, 1} // Ceiling + }, + room.scalar_type()); + origins = torch::tensor( + { + {zero, L, zero}, // West + {W, zero, zero}, // East + {zero, zero, zero}, // South + {W, L, zero}, // North + {W, zero, zero}, // Floor + {W, zero, H} // Ceil + }, + room.scalar_type()); + } + + for (auto i = 0; i < normals.size(0); i++) { + walls.push_back(Wall<scalar_t>( + _absorption.index({at::indexing::Slice(), i}), + _scattering.index({at::indexing::Slice(), i}), + normals[i], + origins[i])); + } + if (D == 2) { + // For consistency with pyroomacoustics we switch the order of the walls + // to South East North West + std::swap(walls[0], walls[3]); + std::swap(walls[0], walls[2]); + } + return walls; + } +}; + +/** + * @brief Compute energy histogram via ray tracing. See the Python wrapper for + * details about parameters and output.
+ */ +torch::Tensor ray_tracing( + const torch::Tensor& room, + const torch::Tensor& source, + const torch::Tensor& mic_array, + int64_t num_rays, + const torch::Tensor& absorption, + const torch::Tensor& scattering, + double mic_radius, + double sound_speed, + double energy_thres, + double time_thres, + double hist_bin_size) { + auto num_mics = mic_array.size(0); + auto num_bands = absorption.size(0); + auto num_bins = (int)ceil(time_thres / hist_bin_size); + // Output shape will actually be (num_mics, num_bands, num_bins). We let + // num_bands be the last dim during computation to make indexing cleaner in + // log_hist(), and to optimize for cache line hits. + auto histograms = + torch::zeros({num_mics, num_bins, num_bands}, room.options()); + + AT_DISPATCH_FLOATING_TYPES(room.scalar_type(), "ray_tracing", [&] { + RayTracer<scalar_t> rt( + room, + source, + mic_array, + num_rays, + absorption, + scattering, + mic_radius, + sound_speed, + energy_thres, + time_thres, + hist_bin_size); + rt.compute_histograms(histograms); + }); + histograms = histograms.transpose( + 1, 2); // back into shape (num_mics, num_bands, num_bins) + return histograms; +} + +} // Anonymous namespace + +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("torchaudio::ray_tracing", torchaudio::rir::ray_tracing); +} + +} // namespace rir +} // namespace torchaudio + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "torchaudio::ray_tracing(Tensor room, Tensor source, Tensor mic_array, int num_rays, Tensor absorption, Tensor scattering, float mic_radius, float sound_speed, float energy_thres, float time_thres, float hist_bin_size) -> Tensor"); +} diff --git a/torchaudio/prototype/functional/__init__.py b/torchaudio/prototype/functional/__init__.py index 599d2113af..afde914355 100644 --- a/torchaudio/prototype/functional/__init__.py +++ b/torchaudio/prototype/functional/__init__.py @@ -1,4 +1,5 @@ from ._dsp import adsr_envelope, extend_pitch, oscillator_bank, sinc_impulse_response +from ._ray_tracing import ray_tracing from .functional import add_noise, barkscale_fbanks, convolve, deemphasis, fftconvolve, preemphasis, speed __all__ = [ @@ -10,6 +11,7 @@ "extend_pitch", "fftconvolve", "oscillator_bank", + "ray_tracing", "preemphasis", "sinc_impulse_response", "speed", diff --git a/torchaudio/prototype/functional/_ray_tracing.py b/torchaudio/prototype/functional/_ray_tracing.py new file mode 100644 index 0000000000..7be689c1ae --- /dev/null +++ b/torchaudio/prototype/functional/_ray_tracing.py @@ -0,0 +1,154 @@ +from typing import Union + +import torch + + +def _validate_absorption_scattering( + v: Union[float, torch.Tensor], name: str, D: int, dtype: torch.dtype +) -> torch.Tensor: + """Validates and converts absorption or scattering parameters to a tensor with the appropriate shape""" + num_walls = 4 if D == 2 else 6 + if isinstance(v, float): + out = torch.full( + size=( + 1, + num_walls, + ), + fill_value=v, + ) + elif isinstance(v, torch.Tensor) and v.ndim == 1: + if v.shape[0] != num_walls: + raise ValueError( + f"The shape of {name} must be (4,) or (6,) if it is a 1D Tensor. " + f"Found the room dimension to be {D} and the shape of {name} to be {v.shape}." + ) + out = v[None, :] + elif isinstance(v, torch.Tensor) and v.ndim == 2: + if v.shape[1] != num_walls: + raise ValueError( + f"The shape of {name} must be (num_bands, 4) for a 2D room or (num_bands, 6) " + "for a 3D room if it is a 2D Tensor. " + f"Found the room dimension to be {D} and the shape of {name} to be {v.shape}."
+ ) + out = v + else: + out = v + assert out.ndim == 2 + out = out.to(dtype) + + return out + + +def ray_tracing( + room: torch.Tensor, + source: torch.Tensor, + mic_array: torch.Tensor, + num_rays: int, + absorption: Union[float, torch.Tensor] = 0.0, + scattering: Union[float, torch.Tensor] = 0.0, + mic_radius: float = 0.5, + sound_speed: float = 343.0, + energy_thres: float = 1e-7, + time_thres: float = 10.0, + hist_bin_size: float = 0.004, +) -> torch.Tensor: + r"""Compute energy histogram via ray tracing. + + The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`. + + ``num_rays`` rays are cast uniformly in all directions from the source; when a ray intersects a wall, + it is reflected and part of its energy is absorbed. It is also scattered (sent directly to the microphone(s)) + according to the ``scattering`` coefficient. When a ray is close to the microphone, its current energy is + recorded in the output histogram for that given time slot. + + .. devices:: CPU + + .. properties:: TorchScript + + Args: + room (torch.Tensor): The room dimensions. The shape is + `(D,)`, where ``D`` is 2 if room is a 2D room, or 3 if room is a 3D room. All rooms + are assumed to be "shoebox" rooms. + source (torch.Tensor): The coordinates of the sound source. Tensor with dimensions `(D,)`. + mic_array (torch.Tensor): The coordinates of the microphone array. Tensor with dimensions `(channel, D)`. + num_rays (int): The number of rays to cast from the source. + absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials. + (Default: ``0.0``). + If the value is a ``float``, the same absorption coefficient is used for all walls and + all frequencies. + If ``absorption`` is a 1D Tensor, the shape must be `(4,)` if the room is a 2D room, + representing absorption coefficients of ``"west"``, ``"east"``, ``"south"``, and + ``"north"`` walls, respectively. + Or the shape must be `(6,)` if the room is a 3D room, representing absorption coefficients + of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively. + If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 4)` if the room is a 2D room, + or `(num_bands, 6)` if the room is a 3D room. ``num_bands`` is the number of frequency bands (usually 7), + but you can choose other values. + scattering (float or torch.Tensor, optional): The scattering coefficients of wall materials. + (Default: ``0.0``). The shape and type of this parameter are the same as for ``absorption``. + mic_radius (float, optional): The radius of the microphone in meters. (Default: ``0.5``) + sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``) + energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``). + The initial energy of each ray is ``2 / num_rays``. + time_thres (float, optional): The maximal duration (in seconds) for which rays are traced. (Default: ``10.0``) + hist_bin_size (float, optional): The size (in seconds) of each bin in the output histogram. (Default: ``0.004``) + + Returns: + (torch.Tensor): The 3D histogram(s) where the energy of the traced rays is recorded. Each bin corresponds + to a given time slot. The shape is `(channel, num_bands, num_bins)`, + where ``num_bins = ceil(time_thres / hist_bin_size)``. If both ``absorption`` and ``scattering`` + are floats, then ``num_bands == 1``.
+ """ + if room.ndim != 1 or room.shape[0] not in (2, 3): + raise ValueError(f"room must be a 1D tensor of shape (2,) or (3,), got shape={room.shape}") + D = room.shape[0] + + if mic_array.ndim == 1: + mic_array = mic_array[None, :] + if mic_array.ndim != 2: + raise ValueError( + "mic_array must be 1D tensor of shape (D,), or 2D tensor of shape (num_mics, D) " + f"where D is 2 or 3. Got shape = {mic_array.shape}." + ) + if room.dtype not in (torch.float32, torch.float64): + raise ValueError(f"room must be of float32 or float64 dtype, got {room.dtype} instead.") + if not (room.dtype == source.dtype == mic_array.dtype): + raise ValueError( + "dtype of room, source and mic_array must be the same. " + f"Got {room.dtype}, {source.dtype}, and {mic_array.dtype}" + ) + if not (D == source.shape[0] == mic_array.shape[1]): + raise ValueError( + "Room dimension D must match with source and mic_array. " + f"Got {D}, {source.shape[0]}, and {mic_array.shape[1]}" + ) + if time_thres < hist_bin_size: + raise ValueError(f"time_thres={time_thres} must be at least greater than hist_bin_size={hist_bin_size}.") + + absorption = _validate_absorption_scattering(absorption, name="absorption", D=D, dtype=room.dtype) + scattering = _validate_absorption_scattering(scattering, name="scattering", D=D, dtype=room.dtype) + + # Bring absorption and scattering to the same shape + if absorption.shape[0] == 1 and scattering.shape[0] > 1: + absorption = absorption.expand(scattering.shape) + if scattering.shape[0] == 1 and absorption.shape[0] > 1: + scattering = scattering.expand(absorption.shape) + if absorption.shape != scattering.shape: + raise ValueError( + "absorption and scattering must have the same number of bands and walls. " + f"Inferred shapes are {absorption.shape} and {scattering.shape}" + ) + + histograms = torch.ops.torchaudio.ray_tracing( + room, + source, + mic_array, + num_rays, + absorption, + scattering, + mic_radius, + sound_speed, + energy_thres, + time_thres, + hist_bin_size, + ) + + return histograms