From 51d36fb6b62f3f3986864e25b4709bc366a2bbfd Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Thu, 24 Jun 2021 23:05:22 +0000 Subject: [PATCH 1/2] remove reuse_logits_for_grads --- test/torchaudio_unittest/rnnt/autograd_impl.py | 3 +-- test/torchaudio_unittest/rnnt/rnnt_loss_impl.py | 14 ++++++-------- test/torchaudio_unittest/rnnt/utils.py | 3 +-- torchaudio/csrc/rnnt/autograd.cpp | 12 ++++-------- torchaudio/csrc/rnnt/compute.cpp | 9 +++------ torchaudio/csrc/rnnt/compute.h | 3 +-- torchaudio/csrc/rnnt/cpu/compute.cpp | 9 ++------- torchaudio/csrc/rnnt/gpu/compute.cu | 9 ++------- torchaudio/prototype/rnnt_loss.py | 13 ++----------- 9 files changed, 22 insertions(+), 53 deletions(-) diff --git a/test/torchaudio_unittest/rnnt/autograd_impl.py b/test/torchaudio_unittest/rnnt/autograd_impl.py index 1d5bf40a3c..8d0ebf9cb1 100644 --- a/test/torchaudio_unittest/rnnt/autograd_impl.py +++ b/test/torchaudio_unittest/rnnt/autograd_impl.py @@ -53,7 +53,7 @@ def test_RNNTLoss_gradcheck(self, data_func): data["logit_lengths"], data["target_lengths"], ) - loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False) + loss = RNNTLoss(blank=data["blank"]) self.assert_grad(loss, inputs, enable_all_grad=False) @@ -72,7 +72,6 @@ def test_rnnt_loss_gradcheck(self, data_func): data["blank"], # blank -1, # clamp True, # fused_log_softmax - False, # reuse_logits_for_grads ) self.assert_grad(rnnt_loss, inputs, enable_all_grad=False) diff --git a/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py index 591ed3bcd9..8903637abb 100644 --- a/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py +++ b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py @@ -17,14 +17,12 @@ def _test_costs_and_gradients( self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2 ): logits_shape = data["logits"].shape - for reuse_logits_for_grads in [False, True]: - with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads): - costs, gradients = 
compute_with_pytorch_transducer( - data=data, reuse_logits_for_grads=reuse_logits_for_grads - ) - self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol) - self.assertEqual(logits_shape, gradients.shape) - self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol) + costs, gradients = compute_with_pytorch_transducer( + data=data, reuse_logits_for_grads=reuse_logits_for_grads + ) + self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol) + self.assertEqual(logits_shape, gradients.shape) + self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol) def test_basic_backward(self): rnnt_loss = RNNTLoss() diff --git a/test/torchaudio_unittest/rnnt/utils.py b/test/torchaudio_unittest/rnnt/utils.py index 3ae7f1913a..ec2933bf1e 100644 --- a/test/torchaudio_unittest/rnnt/utils.py +++ b/test/torchaudio_unittest/rnnt/utils.py @@ -23,11 +23,10 @@ def compute_with_numpy_transducer(data): return costs, gradients -def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False): +def compute_with_pytorch_transducer(data): costs = RNNTLoss( blank=data["blank"], fused_log_softmax=data.get("fused_log_softmax", True), - reuse_logits_for_grads=reuse_logits_for_grads, reduction="none", )( logits=data["logits"], diff --git a/torchaudio/csrc/rnnt/autograd.cpp b/torchaudio/csrc/rnnt/autograd.cpp index 9f0d64822f..59f3e9ebeb 100644 --- a/torchaudio/csrc/rnnt/autograd.cpp +++ b/torchaudio/csrc/rnnt/autograd.cpp @@ -14,8 +14,7 @@ class RNNTLossFunction : public torch::autograd::Function { const torch::Tensor& target_lengths, int64_t blank, double clamp, - bool fused_log_softmax = true, - bool reuse_logits_for_grads = true) { + bool fused_log_softmax = true) { torch::Tensor undef; auto result = rnnt_loss( logits, @@ -24,8 +23,7 @@ class RNNTLossFunction : public torch::autograd::Function { target_lengths, blank, clamp, - fused_log_softmax, - reuse_logits_for_grads); + fused_log_softmax); auto costs = std::get<0>(result); auto grads = std::get<1>(result).value_or(undef); 
ctx->save_for_backward({grads}); @@ -51,8 +49,7 @@ std::tuple> rnnt_loss_autograd( const torch::Tensor& target_lengths, int64_t blank, double clamp, - bool fused_log_softmax = true, - bool reuse_logits_for_grads = true) { + bool fused_log_softmax = true) { at::AutoDispatchBelowADInplaceOrView guard; auto results = RNNTLossFunction::apply( logits, @@ -61,8 +58,7 @@ std::tuple> rnnt_loss_autograd( target_lengths, blank, clamp, - fused_log_softmax, - reuse_logits_for_grads); + fused_log_softmax); return std::make_tuple(results[0], results[1]); } diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp index d54fa4b896..f21413e432 100644 --- a/torchaudio/csrc/rnnt/compute.cpp +++ b/torchaudio/csrc/rnnt/compute.cpp @@ -8,8 +8,7 @@ std::tuple> rnnt_loss( const torch::Tensor& target_lengths, int64_t blank, double clamp, - bool fused_log_softmax = true, - bool reuse_logits_for_grads = true) { + bool fused_log_softmax = true) { static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow("torchaudio::rnnt_loss", "") .typed(); @@ -20,8 +19,7 @@ std::tuple> rnnt_loss( target_lengths, blank, clamp, - fused_log_softmax, - reuse_logits_for_grads); + fused_log_softmax); } TORCH_LIBRARY_FRAGMENT(torchaudio, m) { @@ -32,6 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { "Tensor target_lengths," "int blank," "float clamp," - "bool fused_log_softmax=True," - "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)"); + "bool fused_log_softmax=True) -> (Tensor, Tensor?)"); } diff --git a/torchaudio/csrc/rnnt/compute.h b/torchaudio/csrc/rnnt/compute.h index f913a8599b..0508cec80e 100644 --- a/torchaudio/csrc/rnnt/compute.h +++ b/torchaudio/csrc/rnnt/compute.h @@ -9,5 +9,4 @@ std::tuple> rnnt_loss( const torch::Tensor& target_lengths, int64_t blank, double clamp, - bool fused_log_softmax, - bool reuse_logits_for_grads); + bool fused_log_softmax); diff --git a/torchaudio/csrc/rnnt/cpu/compute.cpp b/torchaudio/csrc/rnnt/cpu/compute.cpp index 
47315c69d9..5d6df5aeab 100644 --- a/torchaudio/csrc/rnnt/cpu/compute.cpp +++ b/torchaudio/csrc/rnnt/cpu/compute.cpp @@ -13,8 +13,7 @@ std::tuple> compute( const torch::Tensor& target_lengths, int64_t blank, double clamp, - bool fused_log_softmax = true, - bool reuse_logits_for_grads = true) { + bool fused_log_softmax = true) { TORCH_CHECK( logits.device().type() == targets.device().type(), "logits and targets must be on the same device"); @@ -92,11 +91,7 @@ std::tuple> compute( torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); c10::optional gradients = c10::nullopt; if (logits.requires_grad()) { - if (reuse_logits_for_grads) { - gradients = logits; - } else { - gradients = torch::zeros_like(logits); - } + gradients = torch::zeros_like(logits); } torch::Tensor int_workspace = torch::empty( diff --git a/torchaudio/csrc/rnnt/gpu/compute.cu b/torchaudio/csrc/rnnt/gpu/compute.cu index 5d02122277..0e7badabe4 100644 --- a/torchaudio/csrc/rnnt/gpu/compute.cu +++ b/torchaudio/csrc/rnnt/gpu/compute.cu @@ -14,8 +14,7 @@ std::tuple> compute( const torch::Tensor& target_lengths, int64_t blank, double clamp, - bool fused_log_softmax = true, - bool reuse_logits_for_grads = true) { + bool fused_log_softmax = true) { TORCH_CHECK( logits.device().type() == targets.device().type(), "logits and targets must be on the same device"); @@ -95,11 +94,7 @@ std::tuple> compute( torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); c10::optional gradients = c10::nullopt; if (logits.requires_grad()) { - if (reuse_logits_for_grads) { - gradients = logits; - } else { - gradients = torch::zeros_like(logits); - } + gradients = torch::zeros_like(logits); } torch::Tensor int_workspace = torch::empty( diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py index 48f7fcdbb7..ffaee638c5 100644 --- a/torchaudio/prototype/rnnt_loss.py +++ b/torchaudio/prototype/rnnt_loss.py @@ -15,7 +15,6 @@ def rnnt_loss( blank: int = -1, clamp: float = 
-1, fused_log_softmax: bool = True, - reuse_logits_for_grads: bool = True, reduction: str = "mean", ): """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks* @@ -33,7 +32,6 @@ def rnnt_loss( blank (int, opt): blank label (Default: ``-1``) clamp (float): clamp for gradients (Default: ``-1``) fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``) - reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``) reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) @@ -46,9 +44,6 @@ def rnnt_loss( if not fused_log_softmax: logits = torch.nn.functional.log_softmax(logits, dim=-1) - reuse_logits_for_grads = ( - False # softmax needs the original logits value - ) if blank < 0: # reinterpret blank index if blank < 0. blank = logits.shape[-1] + blank @@ -60,8 +55,8 @@ def rnnt_loss( target_lengths=target_lengths, blank=blank, clamp=clamp, - fused_log_softmax=fused_log_softmax, - reuse_logits_for_grads=reuse_logits_for_grads,) + fused_log_softmax=fused_log_softmax + ) if reduction == 'mean': return costs.mean() @@ -83,7 +78,6 @@ class RNNTLoss(torch.nn.Module): blank (int, opt): blank label (Default: ``-1``) clamp (float): clamp for gradients (Default: ``-1``) fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``) - reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``) reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. 
(Default: ``'mean'``) """ @@ -93,14 +87,12 @@ def __init__( blank: int = -1, clamp: float = -1., fused_log_softmax: bool = True, - reuse_logits_for_grads: bool = True, reduction: str = "mean", ): super().__init__() self.blank = blank self.clamp = clamp self.fused_log_softmax = fused_log_softmax - self.reuse_logits_for_grads = reuse_logits_for_grads self.reduction = reduction def forward( @@ -129,6 +121,5 @@ def forward( self.blank, self.clamp, self.fused_log_softmax, - self.reuse_logits_for_grads, self.reduction ) From d55aed7e9f97d17d87b5269726a5f3eb80d0e233 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Fri, 9 Jul 2021 19:05:57 +0000 Subject: [PATCH 2/2] remove unnecessary code --- test/torchaudio_unittest/rnnt/rnnt_loss_impl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py index 8903637abb..a9ca72951c 100644 --- a/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py +++ b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py @@ -17,9 +17,7 @@ def _test_costs_and_gradients( self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2 ): logits_shape = data["logits"].shape - costs, gradients = compute_with_pytorch_transducer( - data=data, reuse_logits_for_grads=reuse_logits_for_grads - ) + costs, gradients = compute_with_pytorch_transducer(data=data) self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol) self.assertEqual(logits_shape, gradients.shape) self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)