Skip to content

Remove reuse_logits_for_grads option for RNNTL #1610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions test/torchaudio_unittest/rnnt/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_RNNTLoss_gradcheck(self, data_func):
data["logit_lengths"],
data["target_lengths"],
)
loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False)
loss = RNNTLoss(blank=data["blank"])

self.assert_grad(loss, inputs, enable_all_grad=False)

Expand All @@ -72,7 +72,6 @@ def test_rnnt_loss_gradcheck(self, data_func):
data["blank"], # blank
-1, # clamp
True, # fused_log_softmax
False, # reuse_logits_for_grads
)

self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
Expand Down
12 changes: 4 additions & 8 deletions test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,10 @@ def _test_costs_and_gradients(
self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
):
logits_shape = data["logits"].shape
for reuse_logits_for_grads in [False, True]:
with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads):
costs, gradients = compute_with_pytorch_transducer(
data=data, reuse_logits_for_grads=reuse_logits_for_grads
)
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
costs, gradients = compute_with_pytorch_transducer(data=data)
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)

def test_basic_backward(self):
rnnt_loss = RNNTLoss()
Expand Down
3 changes: 1 addition & 2 deletions test/torchaudio_unittest/rnnt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ def compute_with_numpy_transducer(data):
return costs, gradients


def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
def compute_with_pytorch_transducer(data):
costs = RNNTLoss(
blank=data["blank"],
fused_log_softmax=data.get("fused_log_softmax", True),
reuse_logits_for_grads=reuse_logits_for_grads,
reduction="none",
)(
logits=data["logits"],
Expand Down
12 changes: 4 additions & 8 deletions torchaudio/csrc/rnnt/autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
bool fused_log_softmax = true) {
torch::Tensor undef;
auto result = rnnt_loss(
logits,
Expand All @@ -24,8 +23,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
target_lengths,
blank,
clamp,
fused_log_softmax,
reuse_logits_for_grads);
fused_log_softmax);
auto costs = std::get<0>(result);
auto grads = std::get<1>(result).value_or(undef);
ctx->save_for_backward({grads});
Expand All @@ -51,8 +49,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
bool fused_log_softmax = true) {
at::AutoDispatchBelowADInplaceOrView guard;
auto results = RNNTLossFunction::apply(
logits,
Expand All @@ -61,8 +58,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
target_lengths,
blank,
clamp,
fused_log_softmax,
reuse_logits_for_grads);
fused_log_softmax);
return std::make_tuple(results[0], results[1]);
}

Expand Down
9 changes: 3 additions & 6 deletions torchaudio/csrc/rnnt/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
bool fused_log_softmax = true) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("torchaudio::rnnt_loss", "")
.typed<decltype(rnnt_loss)>();
Expand All @@ -20,8 +19,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
target_lengths,
blank,
clamp,
fused_log_softmax,
reuse_logits_for_grads);
fused_log_softmax);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
Expand All @@ -32,6 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
"Tensor target_lengths,"
"int blank,"
"float clamp,"
"bool fused_log_softmax=True,"
"bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
"bool fused_log_softmax=True) -> (Tensor, Tensor?)");
}
3 changes: 1 addition & 2 deletions torchaudio/csrc/rnnt/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax,
bool reuse_logits_for_grads);
bool fused_log_softmax);
9 changes: 2 additions & 7 deletions torchaudio/csrc/rnnt/cpu/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
bool fused_log_softmax = true) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
Expand Down Expand Up @@ -92,11 +91,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = c10::nullopt;
if (logits.requires_grad()) {
if (reuse_logits_for_grads) {
gradients = logits;
} else {
gradients = torch::zeros_like(logits);
}
gradients = torch::zeros_like(logits);
}

torch::Tensor int_workspace = torch::empty(
Expand Down
9 changes: 2 additions & 7 deletions torchaudio/csrc/rnnt/gpu/compute.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true,
bool reuse_logits_for_grads = true) {
bool fused_log_softmax = true) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
Expand Down Expand Up @@ -95,11 +94,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
c10::optional<torch::Tensor> gradients = c10::nullopt;
if (logits.requires_grad()) {
if (reuse_logits_for_grads) {
gradients = logits;
} else {
gradients = torch::zeros_like(logits);
}
gradients = torch::zeros_like(logits);
}

torch::Tensor int_workspace = torch::empty(
Expand Down
13 changes: 2 additions & 11 deletions torchaudio/prototype/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def rnnt_loss(
blank: int = -1,
clamp: float = -1,
fused_log_softmax: bool = True,
reuse_logits_for_grads: bool = True,
reduction: str = "mean",
):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
Expand All @@ -33,7 +32,6 @@ def rnnt_loss(
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)

Expand All @@ -46,9 +44,6 @@ def rnnt_loss(

if not fused_log_softmax:
logits = torch.nn.functional.log_softmax(logits, dim=-1)
reuse_logits_for_grads = (
False # softmax needs the original logits value
)

if blank < 0: # reinterpret blank index if blank < 0.
blank = logits.shape[-1] + blank
Expand All @@ -60,8 +55,8 @@ def rnnt_loss(
target_lengths=target_lengths,
blank=blank,
clamp=clamp,
fused_log_softmax=fused_log_softmax,
reuse_logits_for_grads=reuse_logits_for_grads,)
fused_log_softmax=fused_log_softmax
)

if reduction == 'mean':
return costs.mean()
Expand All @@ -83,7 +78,6 @@ class RNNTLoss(torch.nn.Module):
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
"""
Expand All @@ -93,14 +87,12 @@ def __init__(
blank: int = -1,
clamp: float = -1.,
fused_log_softmax: bool = True,
reuse_logits_for_grads: bool = True,
reduction: str = "mean",
):
super().__init__()
self.blank = blank
self.clamp = clamp
self.fused_log_softmax = fused_log_softmax
self.reuse_logits_for_grads = reuse_logits_for_grads
self.reduction = reduction

def forward(
Expand Down Expand Up @@ -129,6 +121,5 @@ def forward(
self.blank,
self.clamp,
self.fused_log_softmax,
self.reuse_logits_for_grads,
self.reduction
)