Skip to content

Commit 6f0221a

Browse files
committed
fix numpy backward: be careful to not modify inplace.
1 parent cc0acac commit 6f0221a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

test/torchaudio_unittest/rnnt/numpy_transducer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ def forward(
3333
return costs
3434

3535
@staticmethod
36-
def backward(ctx, output_gradients):
37-
return ctx.grads, None, None, None, None, None, None, None, None
36+
def backward(ctx, grad_output):
37+
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
38+
return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None
3839

3940
@staticmethod
4041
def compute_alpha_one_sequence(log_probs, targets, blank=-1):

0 commit comments

Comments
 (0)