Skip to content

Commit 16be98e

Browse files
committed
Refactor argument validation
1 parent 22281d3 commit 16be98e

File tree

1 file changed

+55
-44
lines changed
  • torchaudio/prototype/functional

1 file changed

+55
-44
lines changed

torchaudio/prototype/functional/_rir.py

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -110,58 +110,68 @@ def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length:
110110
return torch.special.sinc(n - delay) * _hann(n - delay, 2 * pad)
111111

112112

113+
def _adjust_coeff(dim: int, coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor:
114+
"""Validates and converts absorption or scattering parameters to a tensor with appropriate shape
115+
116+
Args:
117+
absorption (float or torch.Tensor): The absorption coefficients of wall materials.
118+
119+
If the dtype is ``float``, the absorption coefficient is identical for all walls and
120+
all frequencies.
121+
122+
If ``absorption`` is a 1D Tensor, the shape must be `(2*dim,)`,
123+
where the values represent absorption coefficients of ``"west"``, ``"east"``,
124+
``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
125+
126+
If ``absorption`` is a 2D Tensor, the shape must be `(7, 2*dim)`,
127+
where 7 represents the number of octave bands.
128+
129+
Returns:
130+
(torch.Tensor): The expanded coefficient.
131+
The shape is `(1, 2*dim)` for single octave band case, and
132+
`(7, 2*dim)` for multi octave band case.
133+
"""
134+
num_walls = 2 * dim
135+
if isinstance(coeffs, float):
136+
return torch.full((1, num_walls), coeffs)
137+
if isinstance(coeffs, Tensor):
138+
if coeffs.ndim == 1:
139+
if coeffs.numel() != num_walls:
140+
raise ValueError(
141+
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor."
142+
f"Found the shape {coeffs.shape}."
143+
)
144+
return coeffs.unsqueeze(0)
145+
if coeffs.ndim == 2:
146+
if coeffs.shape != (7, num_walls):
147+
raise ValueError(
148+
f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor."
149+
f"Found the shape {coeffs.shape}."
150+
)
151+
return coeffs
152+
raise TypeError(f"`{name}` must be float or Tensor.")
153+
154+
113155
def _validate_inputs(
114-
room: torch.Tensor, source: torch.Tensor, mic_array: torch.Tensor, absorption: Union[float, torch.Tensor]
115-
) -> torch.Tensor:
156+
dim: int,
157+
room: torch.Tensor,
158+
source: torch.Tensor,
159+
mic_array: torch.Tensor,
160+
):
116161
"""Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension.
117162
118163
Args:
119164
room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
120165
three dimensions of the room.
121166
source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
122167
mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
123-
absorption (float or torch.Tensor): The absorption coefficients of wall materials.
124-
If the dtype is ``float``, the absorption coefficient is identical for all walls and
125-
all frequencies.
126-
If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent
127-
absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``,
128-
and ``"ceiling"``, respectively.
129-
If ``absorption`` is a 2D Tensor, the shape must be `(7, 6)`, where 7 represents the number of octave bands.
130-
131-
Returns:
132-
(torch.Tensor): The absorption Tensor. The shape is `(1, 6)` for single octave band case,
133-
or `(7, 6)` for multi octave band case.
134168
"""
135-
if room.ndim != 1:
136-
raise ValueError(f"room must be a 1D Tensor. Found {room.shape}.")
137-
D = room.shape[0]
138-
if D != 3:
139-
raise ValueError(f"room must be a 3D room. Found {room.shape}.")
140-
num_wall = 6
141-
if source.shape[0] != D:
142-
raise ValueError(f"The shape of source must be `(3,)`. Found {source.shape}")
143-
if mic_array.ndim != 2:
144-
raise ValueError(f"mic_array must be a 2D Tensor. Found {mic_array.shape}.")
145-
if mic_array.shape[1] != D:
146-
raise ValueError(f"The second dimension of mic_array must be 3. Found {mic_array.shape}.")
147-
if isinstance(absorption, float):
148-
absorption = torch.ones(1, num_wall) * absorption
149-
elif isinstance(absorption, Tensor) and absorption.ndim == 1:
150-
if absorption.shape[0] != num_wall:
151-
raise ValueError(
152-
"The shape of absorption must be `(6,)` if it is a 1D Tensor." f"Found the shape {absorption.shape}."
153-
)
154-
absorption = absorption.unsqueeze(0)
155-
elif isinstance(absorption, Tensor) and absorption.ndim == 2:
156-
if absorption.shape != (7, num_wall):
157-
raise ValueError(
158-
"The shape of absorption must be `(7, 6)` if it is a 2D Tensor."
159-
f"Found the shape of room is {D} and shape of absorption is {absorption.shape}."
160-
)
161-
absorption = absorption
162-
else:
163-
absorption = absorption
164-
return absorption
169+
if not (room.ndim == 1 and room.numel() == dim):
170+
raise ValueError(f"`room` must be a 1D Tensor with {dim} elements. Found {room.shape}.")
171+
if not (source.ndim == 1 and source.numel() == dim):
172+
raise ValueError(f"`source` must be 1D Tensor with {dim} elements. Found {source.shape}.")
173+
if not (mic_array.ndim == 2 and mic_array.shape[1] == dim):
174+
raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, {dim}). Found {mic_array.shape}.")
165175

166176

167177
def simulate_rir_ism(
@@ -220,7 +230,8 @@ def simulate_rir_ism(
220230
of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``.
221231
Users need to tune the values of ``absorption`` to the corresponding frequencies.
222232
"""
223-
absorption = _validate_inputs(room, source, mic_array, absorption)
233+
_validate_inputs(3, room, source, mic_array)
234+
absorption = _adjust_coeff(3, absorption, "absorption")
224235
img_location, att = _compute_image_sources(room, source, max_order, absorption)
225236

226237
# compute distances between image sources and microphones

0 commit comments

Comments
 (0)