Conversation


@vincentqb (Owner) commented on May 27, 2021

Follow-up to pytorch#1532

  • Replace subTest.
  • Call mark_dirty when reuse_logits_for_grads=True.
  • Change the default to reuse_logits_for_grads=False.
  • Change the default to fused_log_softmax=False.

@vincentqb marked this pull request as draft on May 27, 2021, 20:04
@vincentqb force-pushed the rnntautograd_default branch from e68ed4a to 397d20a on May 27, 2021, 21:58
@vincentqb (Owner, Author) commented:

If an input needs to be modified in place (e.g. when reuse_logits_for_grads=True), we may need to call mark_dirty; see the documentation.

diff --git a/torchaudio/csrc/rnnt/autograd.cpp b/torchaudio/csrc/rnnt/autograd.cpp
index 6147756..8bf11f1 100644
--- a/torchaudio/csrc/rnnt/autograd.cpp
+++ b/torchaudio/csrc/rnnt/autograd.cpp
@@ -18,6 +18,9 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       bool reuse_logits_for_grads = false) {
     at::AutoNonVariableTypeMode g;
     torch::Tensor undef;
+    if (reuse_logits_for_grads) {
+        ctx->mark_dirty({logits});
+    }
     auto result = rnnt_loss(
         logits,
         targets,

This, however, leads to:

FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
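
The error is expected: once forward declares that it has written into logits via mark_dirty, autograd refuses to accept a leaf tensor that itself requires grad as that input, and the tests pass exactly such a leaf. A minimal Python sketch reproduces both sides; DoubleInPlace is a hypothetical toy op, not the RNNT kernel:

import torch


class DoubleInPlace(torch.autograd.Function):
    """Toy op that overwrites its input buffer and declares it via ctx.mark_dirty."""

    @staticmethod
    def forward(ctx, x):
        x.mul_(2.0)        # reuse the input's memory, as reuse_logits_for_grads would
        ctx.mark_dirty(x)  # tell autograd the input was modified in place
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2.0


leaf = torch.randn(3, requires_grad=True)

# Fine: the input is a non-leaf tensor (e.g. a clone, or the output of a network).
out = DoubleInPlace.apply(leaf.clone())
out.sum().backward()

# RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
DoubleInPlace.apply(leaf)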

@vincentqb (Owner, Author) commented:

Experiment with mark_dirty, as mentioned in the previous comment:

diff --git a/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
index 3083efb..944a038 100644
--- a/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
+++ b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
@@ -50,7 +50,7 @@ class RNNTLossTest:
         data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5(
             dtype=np.float32
         )
-        data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
+        data = numpy_to_torch(data=data, device=self.device, requires_grad=not reuse_logits_for_grads)
         data["reuse_logits_for_grads"] = reuse_logits_for_grads
         self._test_costs_and_gradients(
             data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
diff --git a/test/torchaudio_unittest/rnnt/utils.py b/test/torchaudio_unittest/rnnt/utils.py
index 0a04104..4026e78 100644
--- a/test/torchaudio_unittest/rnnt/utils.py
+++ b/test/torchaudio_unittest/rnnt/utils.py
@@ -432,7 +432,8 @@ def numpy_to_torch(data, device, requires_grad=True):
     def grad_hook(grad):
         logits.saved_grad = grad.clone()
 
-    logits.register_hook(grad_hook)
+    if requires_grad:
+        logits.register_hook(grad_hook)
 
     data["logits"] = logits
     data["logit_lengths"] = logit_lengths
============================================================================================================= short test summary info =============================================================================================================
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp16_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_1 - RuntimeError: cannot call get_autograd_meta() on undefined tensor
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp16_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp32_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_random_data_with_numpy_fp32_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp16_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_1 - RuntimeError: cannot call get_autograd_meta() on undefined tensor
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp16_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp32_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_random_data_with_numpy_fp32_1 - RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
==================================================================================================== 10 failed, 30 passed, 12 warnings in 5.88s ====================================================================================================
Details

============================= test session starts ==============================
platform linux -- Python 3.8.10, pytest-6.2.4, py-1.10.0, pluggy-0.13.1 -- /private/home/vincentqb/miniconda/envs/torch-nightly/bin/python
cachedir: .pytest_cache
rootdir: /private/home/vincentqb/autograd/audio
plugins: hydra-core-1.0.6
collecting ... collected 40 items

autograd_cpu_test.py::TestAutograd::test_RNNTLoss_gradcheck_0 PASSED     [  2%]
autograd_cpu_test.py::TestAutograd::test_RNNTLoss_gradcheck_1 PASSED     [  5%]
autograd_cpu_test.py::TestAutograd::test_RNNTLoss_gradcheck_2 PASSED     [  7%]
autograd_cpu_test.py::TestAutograd::test_np_transducer_gradcheck_0 PASSED [ 10%]
autograd_cpu_test.py::TestAutograd::test_np_transducer_gradcheck_1 PASSED [ 12%]
autograd_cpu_test.py::TestAutograd::test_np_transducer_gradcheck_2 PASSED [ 15%]
autograd_cuda_test.py::TestAutograd::test_RNNTLoss_gradcheck_0 PASSED    [ 17%]
autograd_cuda_test.py::TestAutograd::test_RNNTLoss_gradcheck_1 PASSED    [ 20%]
autograd_cuda_test.py::TestAutograd::test_RNNTLoss_gradcheck_2 PASSED    [ 22%]
autograd_cuda_test.py::TestAutograd::test_np_transducer_gradcheck_0 PASSED [ 25%]
autograd_cuda_test.py::TestAutograd::test_np_transducer_gradcheck_1 PASSED [ 27%]
autograd_cuda_test.py::TestAutograd::test_np_transducer_gradcheck_2 PASSED [ 30%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_basic_backward PASSED          [ 32%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp16_0 PASSED [ 35%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp16_1 FAILED [ 37%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_0 PASSED [ 40%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_1 FAILED [ 42%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp16_0 PASSED [ 45%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp16_1 FAILED [ 47%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp32_0 PASSED [ 50%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp32_1 FAILED [ 52%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_random_data_with_numpy_fp32_0 PASSED [ 55%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_random_data_with_numpy_fp32_1 FAILED [ 57%]
rnnt_loss_cpu_test.py::TestRNNTLoss::test_rnnt_nonfused_log_softmax PASSED [ 60%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_basic_backward PASSED         [ 62%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp16_0 PASSED [ 65%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp16_1 FAILED [ 67%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_0 PASSED [ 70%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_1 FAILED [ 72%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp16_0 PASSED [ 75%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp16_1 FAILED [ 77%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp32_0 PASSED [ 80%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp32_1 FAILED [ 82%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_random_data_with_numpy_fp32_0 PASSED [ 85%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_random_data_with_numpy_fp32_1 FAILED [ 87%]
rnnt_loss_cuda_test.py::TestRNNTLoss::test_rnnt_nonfused_log_softmax PASSED [ 90%]
torchscript_consistency_cpu_test.py::TestRNNTLoss::test_RNNTLoss PASSED  [ 92%]
torchscript_consistency_cpu_test.py::TestRNNTLoss::test_rnnt_loss PASSED [ 95%]
torchscript_consistency_cuda_test.py::TestRNNTLoss::test_RNNTLoss PASSED [ 97%]
torchscript_consistency_cuda_test.py::TestRNNTLoss::test_rnnt_loss PASSED [100%]

=================================== FAILURES ===================================
___________ TestRNNTLoss.test_costs_and_gradients_B1_T2_U3_D5_fp16_1 ___________

a = (<torchaudio_unittest.rnnt.rnnt_loss_cpu_test.TestRNNTLoss testMethod=test_costs_and_gradients_B1_T2_U3_D5_fp16_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:69: in test_costs_and_gradients_B1_T2_U3_D5_fp16
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[ 0.1770, -0.3999,  0.1770,  0.1770, -0.1311],
          [ 0.1225,  0.1225, -0.1818,  0.1225, -0.1857],
    ...,  0.1207],
          [ 0.3074,  0.1687,  0.1864,  0.1687, -0.8315]]]],
       dtype=torch.float16, requires_grad=True)
targets = tensor([[1, 2]], dtype=torch.int32)
logit_lengths = tensor([2], dtype=torch.int32)
target_lengths = tensor([2], dtype=torch.int32), blank = 4, clamp = -1.0
fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
___________ TestRNNTLoss.test_costs_and_gradients_B1_T2_U3_D5_fp32_1 ___________

a = (<torchaudio_unittest.rnnt.rnnt_loss_cpu_test.TestRNNTLoss testMethod=test_costs_and_gradients_B1_T2_U3_D5_fp32_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:55: in test_costs_and_gradients_B1_T2_U3_D5_fp32
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[0.1000, 0.6000, 0.1000, 0.1000, 0.1000],
          [0.1000, 0.1000, 0.6000, 0.1000, 0.1000],
          [0.1...00, 0.1000],
          [0.1000, 0.1000, 0.2000, 0.1000, 0.1000],
          [0.7000, 0.1000, 0.2000, 0.1000, 0.1000]]]])
targets = tensor([[1, 2]], dtype=torch.int32)
logit_lengths = tensor([2], dtype=torch.int32)
target_lengths = tensor([2], dtype=torch.int32), blank = 4, clamp = -1.0
fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: cannot call get_autograd_meta() on undefined tensor

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
___________ TestRNNTLoss.test_costs_and_gradients_B2_T4_U3_D3_fp16_1 ___________

a = (<torchaudio_unittest.rnnt.rnnt_loss_cpu_test.TestRNNTLoss testMethod=test_costs_and_gradients_B2_T4_U3_D3_fp16_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:101: in test_costs_and_gradients_B2_T4_U3_D3_fp16
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[-1.8677e-01, -6.2622e-02,  2.4939e-01],
          [-2.0337e-01,  2.0239e-01,  9.6893e-04],
          [-1.40...01,  1.1646e-01],
          [-5.9863e-01,  3.0225e-01,  2.9639e-01]]]], dtype=torch.float16,
       requires_grad=True)
targets = tensor([[1, 2],
        [1, 1]], dtype=torch.int32)
logit_lengths = tensor([4, 4], dtype=torch.int32)
target_lengths = tensor([2, 2], dtype=torch.int32), blank = 0, clamp = -1.0
fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
___________ TestRNNTLoss.test_costs_and_gradients_B2_T4_U3_D3_fp32_1 ___________

a = (<torchaudio_unittest.rnnt.rnnt_loss_cpu_test.TestRNNTLoss testMethod=test_costs_and_gradients_B2_T4_U3_D3_fp32_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:87: in test_costs_and_gradients_B2_T4_U3_D3_fp32
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[-1.8684e-01, -6.2555e-02,  2.4940e-01],
          [-2.0338e-01,  2.0240e-01,  9.7743e-04],
          [-1.41...     [ 1.2233e-01, -2.3879e-01,  1.1646e-01],
          [-5.9869e-01,  3.0220e-01,  2.9648e-01]]]], requires_grad=True)
targets = tensor([[1, 2],
        [1, 1]], dtype=torch.int32)
logit_lengths = tensor([4, 4], dtype=torch.int32)
target_lengths = tensor([2, 2], dtype=torch.int32), blank = 0, clamp = -1.0
fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
_____ TestRNNTLoss.test_costs_and_gradients_random_data_with_numpy_fp32_1 ______

a = (<torchaudio_unittest.rnnt.rnnt_loss_cpu_test.TestRNNTLoss testMethod=test_costs_and_gradients_random_data_with_numpy_fp32_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:120: in test_costs_and_gradients_random_data_with_numpy_fp32
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[ 1.4938e-01,  7.5933e-02, -3.4599e-01,  ...,  6.9515e-02,
            1.0250e-01, -3.5009e-01],
          [...[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]]], requires_grad=True)
targets = tensor([[2, 4, 5, 7, 1, 7, 3, 7, 2, 2, 2, 6, 7, 5, 4, 0, 6, 2, 2],
        [4, 7, 7, 0, 0, 0, 3, 2, 2, 0, 3, 5, 3, 4, ... 2, 4, 2, 2, 4, 0, 1, 4],
        [5, 0, 5, 0, 7, 0, 6, 5, 7, 1, 6, 6, 5, 4, 7, 1, 0, 1, 4]],
       dtype=torch.int32)
logit_lengths = tensor([34, 30, 29, 23], dtype=torch.int32)
target_lengths = tensor([18, 19,  5,  6], dtype=torch.int32), blank = 8
clamp = -1.0, fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
___________ TestRNNTLoss.test_costs_and_gradients_B1_T2_U3_D5_fp16_1 ___________

a = (<torchaudio_unittest.rnnt.rnnt_loss_cuda_test.TestRNNTLoss testMethod=test_costs_and_gradients_B1_T2_U3_D5_fp16_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:69: in test_costs_and_gradients_B1_T2_U3_D5_fp16
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[ 0.1770, -0.3999,  0.1770,  0.1770, -0.1312],
          [ 0.1225,  0.1225, -0.1818,  0.1225, -0.1857],
    ...     [ 0.3074,  0.1687,  0.1864,  0.1687, -0.8315]]]], device='cuda:0',
       dtype=torch.float16, requires_grad=True)
targets = tensor([[1, 2]], device='cuda:0', dtype=torch.int32)
logit_lengths = tensor([2], device='cuda:0', dtype=torch.int32)
target_lengths = tensor([2], device='cuda:0', dtype=torch.int32), blank = 4
clamp = -1.0, fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
___________ TestRNNTLoss.test_costs_and_gradients_B1_T2_U3_D5_fp32_1 ___________

a = (<torchaudio_unittest.rnnt.rnnt_loss_cuda_test.TestRNNTLoss testMethod=test_costs_and_gradients_B1_T2_U3_D5_fp32_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:55: in test_costs_and_gradients_B1_T2_U3_D5_fp32
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[0.1000, 0.6000, 0.1000, 0.1000, 0.1000],
          [0.1000, 0.1000, 0.6000, 0.1000, 0.1000],
          [0.1...      [0.1000, 0.1000, 0.2000, 0.1000, 0.1000],
          [0.7000, 0.1000, 0.2000, 0.1000, 0.1000]]]], device='cuda:0')
targets = tensor([[1, 2]], device='cuda:0', dtype=torch.int32)
logit_lengths = tensor([2], device='cuda:0', dtype=torch.int32)
target_lengths = tensor([2], device='cuda:0', dtype=torch.int32), blank = 4
clamp = -1.0, fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: cannot call get_autograd_meta() on undefined tensor

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
___________ TestRNNTLoss.test_costs_and_gradients_B2_T4_U3_D3_fp16_1 ___________

a = (<torchaudio_unittest.rnnt.rnnt_loss_cuda_test.TestRNNTLoss testMethod=test_costs_and_gradients_B2_T4_U3_D3_fp16_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:101: in test_costs_and_gradients_B2_T4_U3_D3_fp16
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[-1.8677e-01, -6.2469e-02,  2.4939e-01],
          [-2.0337e-01,  2.0239e-01,  9.8896e-04],
          [-1.40...
          [-5.9863e-01,  3.0225e-01,  2.9639e-01]]]], device='cuda:0',
       dtype=torch.float16, requires_grad=True)
targets = tensor([[1, 2],
        [1, 1]], device='cuda:0', dtype=torch.int32)
logit_lengths = tensor([4, 4], device='cuda:0', dtype=torch.int32)
target_lengths = tensor([2, 2], device='cuda:0', dtype=torch.int32), blank = 0
clamp = -1.0, fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
___________ TestRNNTLoss.test_costs_and_gradients_B2_T4_U3_D3_fp32_1 ___________

a = (<torchaudio_unittest.rnnt.rnnt_loss_cuda_test.TestRNNTLoss testMethod=test_costs_and_gradients_B2_T4_U3_D3_fp32_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:87: in test_costs_and_gradients_B2_T4_U3_D3_fp32
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[-1.8684e-01, -6.2555e-02,  2.4940e-01],
          [-2.0338e-01,  2.0240e-01,  9.7743e-04],
          [-1.41...79e-01,  1.1646e-01],
          [-5.9869e-01,  3.0220e-01,  2.9648e-01]]]], device='cuda:0',
       requires_grad=True)
targets = tensor([[1, 2],
        [1, 1]], device='cuda:0', dtype=torch.int32)
logit_lengths = tensor([4, 4], device='cuda:0', dtype=torch.int32)
target_lengths = tensor([2, 2], device='cuda:0', dtype=torch.int32), blank = 0
clamp = -1.0, fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
_____ TestRNNTLoss.test_costs_and_gradients_random_data_with_numpy_fp32_1 ______

a = (<torchaudio_unittest.rnnt.rnnt_loss_cuda_test.TestRNNTLoss testMethod=test_costs_and_gradients_random_data_with_numpy_fp32_1>,)

    @wraps(func)
    def standalone_func(*a):
>       return func(*(a + p.args), **p.kwargs)

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/parameterized/parameterized.py:533: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
rnnt_loss_impl.py:120: in test_costs_and_gradients_random_data_with_numpy_fp32
    self._test_costs_and_gradients(
rnnt_loss_impl.py:22: in _test_costs_and_gradients
    costs, gradients = compute_with_pytorch_transducer(data=data)
utils.py:30: in compute_with_pytorch_transducer
    costs = RNNTLoss(
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: in _call_impl
    return forward_call(*input, **kwargs)
../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:101: in forward
    return rnnt_loss(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

logits = tensor([[[[ 1.4938e-01,  7.5933e-02, -3.4599e-01,  ...,  6.9515e-02,
            1.0250e-01, -3.5009e-01],
          [...0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]]], device='cuda:0', requires_grad=True)
targets = tensor([[2, 4, 5, 7, 1, 7, 3, 7, 2, 2, 2, 6, 7, 5, 4, 0, 6, 2, 2],
        [4, 7, 7, 0, 0, 0, 3, 2, 2, 0, 3, 5, 3, 4, ..., 1, 4],
        [5, 0, 5, 0, 7, 0, 6, 5, 7, 1, 6, 6, 5, 4, 7, 1, 0, 1, 4]],
       device='cuda:0', dtype=torch.int32)
logit_lengths = tensor([34, 30, 29, 23], device='cuda:0', dtype=torch.int32)
target_lengths = tensor([18, 19,  5,  6], device='cuda:0', dtype=torch.int32)
blank = 8, clamp = -1.0, fused_log_softmax = True, reuse_logits_for_grads = True

    def rnnt_loss(
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
        blank: int = -1,
        clamp: float = -1,
        fused_log_softmax: bool = True,
        reuse_logits_for_grads: bool = False,
    ):
        """
        Compute the RNN Transducer Loss.
    
        The RNN Transducer loss (`Graves 2012 <https://arxiv.org/pdf/1211.3711.pdf>`__) extends the CTC loss by defining
        a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output
        dependencies.
    
        Args:
            logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
            blank (int, opt): blank label (Default: ``-1``)
            clamp (float): clamp for gradients (Default: ``-1``)
            runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
            fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
            reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``False``)
        """
        if not fused_log_softmax:
            logits = torch.nn.functional.log_softmax(logits, dim=-1)
            # softmax needs the original logits value
            assert not reuse_logits_for_grads, "fused_log_softmax=True requires reuse_logits_for_grads=False"
    
        if blank < 0:  # reinterpret blank index if blank < 0.
            blank = logits.shape[-1] + blank
    
>       costs, gradients = torch.ops.torchaudio.rnnt_loss(
            logits=logits,
            targets=targets,
            src_lengths=logit_lengths,
            tgt_lengths=target_lengths,
            blank=blank,
            clamp=clamp,
            fused_log_smax=fused_log_softmax,
            reuse_logits_for_grads=reuse_logits_for_grads,)
E       RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

../../../../../miniconda/envs/torch-nightly/lib/python3.8/site-packages/torchaudio/prototype/rnnt_loss.py:46: RuntimeError
=============================== warnings summary ===============================
test/torchaudio_unittest/rnnt/autograd_cpu_test.py: 6 warnings
test/torchaudio_unittest/rnnt/autograd_cuda_test.py: 6 warnings
  /private/home/vincentqb/miniconda/envs/torch-nightly/lib/python3.8/site-packages/torch/autograd/gradcheck.py:635: UserWarning: Input #0 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex. 
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/warnings.html
=========================== short test summary info ============================
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp16_1
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_1
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp16_1
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp32_1
FAILED rnnt_loss_cpu_test.py::TestRNNTLoss::test_costs_and_gradients_random_data_with_numpy_fp32_1
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp16_1
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B1_T2_U3_D5_fp32_1
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp16_1
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_B2_T4_U3_D3_fp32_1
FAILED rnnt_loss_cuda_test.py::TestRNNTLoss::test_costs_and_gradients_random_data_with_numpy_fp32_1
================== 10 failed, 30 passed, 12 warnings in 6.13s ==================

We'll remove the option reuse_logits_for_grads for now.
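
For reference, here is a minimal call under the resulting defaults. The module path and signature are taken from the traceback above (prototype API at the time of this PR, so both may have moved since), and the shapes follow the B1_T2_U3_D5 case:

import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss  # path as shown in the traceback

# (batch, time, target length + 1, class) joiner output; a leaf tensor is fine here
# because the default no longer reuses the logits buffer for the gradients.
logits = torch.randn(1, 2, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

costs = rnnt_loss(logits, targets, logit_lengths, target_lengths, blank=4)
costs.sum().backward()  # gradients land in logits.grad as usual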

@vincentqb force-pushed the rnntautograd branch 3 times, most recently from 484c217 to 3425bac, on June 3, 2021, 21:56