Skip to content

Commit e0c4415

Browse files
Update RNN tests and implementation
1 parent 4953583 commit e0c4415

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

keras/src/layers/rnn/rnn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,12 @@ def inner_loop(self, sequences, initial_state, mask, training=False):
331331
cell_kwargs["training"] = training
332332

333333
def step(inputs, states):
334+
"""
335+
Create new tensor copies when using PyTorch backend
336+
with stateful=True. This prevents in-place modifications
337+
that would otherwise break PyTorch's autograd functionality
338+
by modifying tensors needed for gradient computation.
339+
"""
334340
if backend.backend() == "torch" and self.stateful:
335341
states = tree.map_structure(ops.copy, states)
336342
output, new_states = self.cell(inputs, states, **cell_kwargs)

keras/src/layers/rnn/rnn_test.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -385,21 +385,17 @@ def test_serialization(self):
385385

386386
@pytest.mark.torch
387387
def test_stateful_state_copying(self):
388-
"""Test that states are properly copied (not referenced) in PyTorch backend when stateful=True."""
389388
sequence = np.ones((1, 2, 3))
390389
layer = layers.RNN(
391390
TwoStatesRNNCell(2),
392391
stateful=True,
393392
return_state=True,
394393
)
395-
396-
# Single forward pass
397394
_, state1, state2 = layer(sequence)
398-
399-
# Check that layer.states contains clones of the states, not references
400-
self.assertIsNot(state1, layer.states[0]) # Should be different objects
395+
396+
self.assertIsNot(state1, layer.states[0])
401397
self.assertIsNot(state2, layer.states[1])
402-
self.assertAllClose(state1, layer.states[0]) # but with same values
398+
self.assertAllClose(state1, layer.states[0])
403399
self.assertAllClose(state2, layer.states[1])
404400

405401
# TODO: test masking

0 commit comments

Comments
 (0)