Skip to content

Commit 345b64d

Browse files
mattdangerwsampathweb
authored andcommitted
Fix bug when upranking passthrough inputs to RandAugment (#2194)
- RandAugment sometimes will choose a "no augmentation" option and passthrough inputs unaltered. - Preprocessing normalization routines were not making copies of inputs and sometimes mutating layer input directly (mutating the input dict to cast dtypes and uprank tensors). - RandAugment under the passthrough option would return these inputs directly. The net effect was sometimes attempting to uprank during a passthrough call, breaking tf.map_fn
1 parent 2488855 commit 345b64d

File tree

4 files changed

+17
-4
lines changed

4 files changed

+17
-4
lines changed

keras_cv/layers/preprocessing/base_image_augmentation_layer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,8 @@ def _ensure_inputs_are_compute_dtype(self, inputs):
571571
inputs,
572572
self.compute_dtype,
573573
)
574+
# Copy the input dict before we mutate it.
575+
inputs = dict(inputs)
574576
inputs[IMAGES] = preprocessing.ensure_tensor(
575577
inputs[IMAGES],
576578
self.compute_dtype,

keras_cv/layers/preprocessing/rand_augment_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import numpy as np
15-
import pytest
1615
import tensorflow as tf
1716
from absl.testing import parameterized
1817

1918
from keras_cv import layers
20-
from keras_cv.backend.config import keras_3
2119
from keras_cv.tests.test_case import TestCase
2220

2321

24-
@pytest.mark.skipif(keras_3(), reason="imcompatible with Keras 3")
2522
class RandAugmentTest(TestCase):
23+
def test_zero_rate_pass_through(self):
24+
rand_augment = layers.RandAugment(
25+
value_range=(0, 255),
26+
rate=0.0,
27+
)
28+
xs = np.ones((2, 512, 512, 3))
29+
ys = rand_augment(xs)
30+
self.assertAllClose(ys, xs)
31+
2632
@parameterized.named_parameters(
2733
("0", 0),
2834
("20", 0.2),

keras_cv/layers/preprocessing/random_augmentation_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _augment(self, inputs):
103103
)
104104
result = tf.cond(
105105
skip_augment > self.rate,
106-
lambda: inputs,
106+
lambda: result,
107107
lambda: self._random_choice(result),
108108
)
109109
return result

keras_cv/layers/preprocessing/vectorized_base_image_augmentation_layer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ def _format_inputs(self, inputs):
444444
# single image input tensor
445445
metadata[IS_DICT] = False
446446
inputs = {IMAGES: inputs}
447+
else:
448+
# Copy the input dict before we mutate it.
449+
inputs = dict(inputs)
447450

448451
metadata[BATCHED] = inputs["images"].shape.rank == 4
449452
if inputs["images"].shape.rank == 3:
@@ -504,6 +507,8 @@ def _ensure_inputs_are_compute_dtype(self, inputs):
504507
inputs,
505508
self.compute_dtype,
506509
)
510+
# Copy the input dict before we mutate it.
511+
inputs = dict(inputs)
507512
inputs[IMAGES] = preprocessing.ensure_tensor(
508513
inputs[IMAGES],
509514
self.compute_dtype,

0 commit comments

Comments
 (0)