
Commit 16f3b2f

Author: Caroline Chen
Remove reuse_logits_for_grads option for RNNTLoss (#1610)
1 parent 25ceee7 commit 16f3b2f

File tree

9 files changed: 20 additions, 53 deletions
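
In short, the commit removes the reuse_logits_for_grads flag from every layer of the RNNT loss: the prototype Python wrapper, the registered operator schema, the CPU and GPU kernels, and the tests. When gradients are requested they are now always written to a buffer allocated with torch::zeros_like instead of optionally aliasing the logits tensor. A minimal before/after sketch of the user-facing Python API, assuming the prototype import path used by the files below:

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

# Before #1610 the flag could be set explicitly:
#     loss_fn = RNNTLoss(blank=-1, reuse_logits_for_grads=False)
# After #1610 the option is gone; the remaining arguments are unchanged:
loss_fn = RNNTLoss(blank=-1, clamp=-1.0, fused_log_softmax=True, reduction="mean")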

test/torchaudio_unittest/rnnt/autograd_impl.py

Lines changed: 1 addition & 2 deletions
@@ -53,7 +53,7 @@ def test_RNNTLoss_gradcheck(self, data_func):
             data["logit_lengths"],
             data["target_lengths"],
         )
-        loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False)
+        loss = RNNTLoss(blank=data["blank"])

         self.assert_grad(loss, inputs, enable_all_grad=False)

@@ -72,7 +72,6 @@ def test_rnnt_loss_gradcheck(self, data_func):
             data["blank"],  # blank
             -1,  # clamp
             True,  # fused_log_softmax
-            False,  # reuse_logits_for_grads
         )

         self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)

test/torchaudio_unittest/rnnt/rnnt_loss_impl.py

Lines changed: 4 additions & 8 deletions
@@ -17,14 +17,10 @@ def _test_costs_and_gradients(
         self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
     ):
         logits_shape = data["logits"].shape
-        for reuse_logits_for_grads in [False, True]:
-            with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads):
-                costs, gradients = compute_with_pytorch_transducer(
-                    data=data, reuse_logits_for_grads=reuse_logits_for_grads
-                )
-                self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
-                self.assertEqual(logits_shape, gradients.shape)
-                self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
+        costs, gradients = compute_with_pytorch_transducer(data=data)
+        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
+        self.assertEqual(logits_shape, gradients.shape)
+        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)

     def test_basic_backward(self):
         rnnt_loss = RNNTLoss()

test/torchaudio_unittest/rnnt/utils.py

Lines changed: 1 addition & 2 deletions
@@ -23,11 +23,10 @@ def compute_with_numpy_transducer(data):
     return costs, gradients


-def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
+def compute_with_pytorch_transducer(data):
     costs = RNNTLoss(
         blank=data["blank"],
         fused_log_softmax=data.get("fused_log_softmax", True),
-        reuse_logits_for_grads=reuse_logits_for_grads,
         reduction="none",
     )(
         logits=data["logits"],

torchaudio/csrc/rnnt/autograd.cpp

Lines changed: 4 additions & 8 deletions
@@ -14,8 +14,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       const torch::Tensor& target_lengths,
       int64_t blank,
       double clamp,
-      bool fused_log_softmax = true,
-      bool reuse_logits_for_grads = true) {
+      bool fused_log_softmax = true) {
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
@@ -24,8 +23,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
        target_lengths,
        blank,
        clamp,
-        fused_log_softmax,
-        reuse_logits_for_grads);
+        fused_log_softmax);
    auto costs = std::get<0>(result);
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
@@ -51,8 +49,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
       logits,
@@ -61,8 +58,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
       target_lengths,
       blank,
       clamp,
-      fused_log_softmax,
-      reuse_logits_for_grads);
+      fused_log_softmax);
   return std::make_tuple(results[0], results[1]);
 }


torchaudio/csrc/rnnt/compute.cpp

Lines changed: 3 additions & 6 deletions
@@ -8,8 +8,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                        .typed<decltype(rnnt_loss)>();
@@ -20,8 +19,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
       target_lengths,
       blank,
       clamp,
-      fused_log_softmax,
-      reuse_logits_for_grads);
+      fused_log_softmax);
 }

 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -32,6 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "Tensor target_lengths,"
       "int blank,"
       "float clamp,"
-      "bool fused_log_softmax=True,"
-      "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
+      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
 }
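
The torchaudio::rnnt_loss schema trimmed above is also what torch.ops resolves on the Python side. Below is a rough sketch of a direct call against the updated schema; the leading argument names and the toy shapes are inferred from the surrounding signatures rather than taken from this diff, and the wrapper in torchaudio/prototype/rnnt_loss.py remains the supported entry point:

import torch
import torchaudio.prototype.rnnt_loss  # noqa: F401  # assumed to load the extension registering the op

# Illustrative inputs; logits follow the usual RNNT layout
# (batch, time, target_len + 1, num_classes).
logits = torch.rand(1, 4, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([4], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

costs, grads = torch.ops.torchaudio.rnnt_loss(
    logits, targets, logit_lengths, target_lengths,
    4,     # blank as an absolute class index (the Python wrapper converts negative values)
    -1.0,  # clamp (disabled)
    True,  # fused_log_softmax; reuse_logits_for_grads no longer exists
)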

torchaudio/csrc/rnnt/compute.h

Lines changed: 1 addition & 2 deletions
@@ -9,5 +9,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax,
-    bool reuse_logits_for_grads);
+    bool fused_log_softmax);

torchaudio/csrc/rnnt/cpu/compute.cpp

Lines changed: 2 additions & 7 deletions
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -92,11 +91,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
   c10::optional<torch::Tensor> gradients = c10::nullopt;
   if (logits.requires_grad()) {
-    if (reuse_logits_for_grads) {
-      gradients = logits;
-    } else {
-      gradients = torch::zeros_like(logits);
-    }
+    gradients = torch::zeros_like(logits);
   }

   torch::Tensor int_workspace = torch::empty(

torchaudio/csrc/rnnt/gpu/compute.cu

Lines changed: 2 additions & 7 deletions
@@ -14,8 +14,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& target_lengths,
     int64_t blank,
     double clamp,
-    bool fused_log_softmax = true,
-    bool reuse_logits_for_grads = true) {
+    bool fused_log_softmax = true) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -95,11 +94,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
       torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
   c10::optional<torch::Tensor> gradients = c10::nullopt;
   if (logits.requires_grad()) {
-    if (reuse_logits_for_grads) {
-      gradients = logits;
-    } else {
-      gradients = torch::zeros_like(logits);
-    }
+    gradients = torch::zeros_like(logits);
   }

   torch::Tensor int_workspace = torch::empty(

torchaudio/prototype/rnnt_loss.py

Lines changed: 2 additions & 11 deletions
@@ -15,7 +15,6 @@ def rnnt_loss(
     blank: int = -1,
     clamp: float = -1,
     fused_log_softmax: bool = True,
-    reuse_logits_for_grads: bool = True,
     reduction: str = "mean",
 ):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -33,7 +32,6 @@ def rnnt_loss(
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)

@@ -46,9 +44,6 @@ def rnnt_loss(

     if not fused_log_softmax:
         logits = torch.nn.functional.log_softmax(logits, dim=-1)
-        reuse_logits_for_grads = (
-            False  # softmax needs the original logits value
-        )

     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
@@ -60,8 +55,8 @@ def rnnt_loss(
         target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_softmax=fused_log_softmax,
-        reuse_logits_for_grads=reuse_logits_for_grads,)
+        fused_log_softmax=fused_log_softmax
+    )

     if reduction == 'mean':
         return costs.mean()
@@ -83,7 +78,6 @@ class RNNTLoss(torch.nn.Module):
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
     """
@@ -93,14 +87,12 @@ def __init__(
         blank: int = -1,
         clamp: float = -1.,
         fused_log_softmax: bool = True,
-        reuse_logits_for_grads: bool = True,
         reduction: str = "mean",
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
         self.fused_log_softmax = fused_log_softmax
-        self.reuse_logits_for_grads = reuse_logits_for_grads
         self.reduction = reduction

     def forward(
@@ -129,6 +121,5 @@ def forward(
             self.blank,
             self.clamp,
             self.fused_log_softmax,
-            self.reuse_logits_for_grads,
             self.reduction
         )
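
Putting the Python-side changes together, here is a small usage sketch of the updated prototype module. The import path matches the file above; the toy shapes, dtypes, and the presence of an RNNT-enabled build are assumptions made for illustration:

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

# Toy inputs: batch=1, 4 time steps, target length 2, 5 output classes.
logits = torch.rand(1, 4, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([4], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# reuse_logits_for_grads is no longer accepted by the constructor.
loss_fn = RNNTLoss()  # blank defaults to -1, i.e. the last class
loss = loss_fn(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
)
loss.backward()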
