Skip to content

Commit e25de5e

Browse files
authored
Update base image aug layer tensor conversion (#2281)
1 parent 72d6120 commit e25de5e

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

keras_cv/layers/preprocessing/base_image_augmentation_layer.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from keras_cv import bounding_box
2525
from keras_cv.api_export import keras_cv_export
26+
from keras_cv.backend import config
2627
from keras_cv.backend import keras
2728
from keras_cv.backend import ops
2829
from keras_cv.backend import scope
@@ -411,14 +412,16 @@ def get_random_transformation(
411412
def call(self, inputs):
412413
# try to convert a given backend native tensor to TensorFlow tensor
413414
# before passing it over to TFDataScope
415+
is_tf_backend = config.backend() == "tensorflow"
416+
is_in_tf_graph = not tf.executing_eagerly()
414417
contains_ragged = lambda y: any(
415418
tree.map_structure(
416419
lambda x: isinstance(x, (tf.RaggedTensor, tf.SparseTensor)),
417420
tree.flatten(y),
418421
)
419422
)
420423
inputs_contain_ragged = contains_ragged(inputs)
421-
if not inputs_contain_ragged:
424+
if not is_tf_backend and not inputs_contain_ragged:
422425
inputs = tree.map_structure(
423426
lambda x: tf.convert_to_tensor(x), inputs
424427
)
@@ -444,13 +447,15 @@ def call(self, inputs):
444447
# backend native tensors. This is to avoid breaking TF data
445448
# pipelines that can't easily be ported to become backend
446449
# agnostic.
447-
if not inputs_contain_ragged and not contains_ragged(outputs):
448-
outputs = tree.map_structure(
449-
# some layers return None, handle that case when
450-
# converting to tensors
451-
lambda x: ops.convert_to_tensor(x) if x is not None else x,
452-
outputs,
453-
)
450+
# Skip this step for TF backend or if in `tf.graph` like `tf.data`.
451+
if not is_tf_backend and not is_in_tf_graph:
452+
if not inputs_contain_ragged and not contains_ragged(outputs):
453+
outputs = tree.map_structure(
454+
# some layers return None, handle that case when
455+
# converting to tensors
456+
lambda x: ops.convert_to_tensor(x) if x is not None else x,
457+
outputs,
458+
)
454459
return outputs
455460

456461
def _augment(self, inputs):

keras_cv/layers/preprocessing/base_image_augmentation_layer_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,11 @@ def in_tf_function(inputs):
265265
self.assertNotAllClose(
266266
segmentation_mask_diff[0], segmentation_mask_diff[1]
267267
)
268+
269+
def test_augment_tf_data_pipeline(self):
270+
image = np.random.random(size=(1, 8, 8, 3)).astype("float32")
271+
tf_dataset = tf.data.Dataset.from_tensor_slices(image).map(
272+
RandomAddLayer(fixed_value=2.0)
273+
)
274+
output = iter(tf_dataset).get_next()
275+
self.assertAllClose(image[0] + 2.0, output)

0 commit comments

Comments
 (0)