Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions keras_cv/layers/preprocessing/base_image_augmentation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import config
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import scope
Expand Down Expand Up @@ -411,14 +412,16 @@ def get_random_transformation(
def call(self, inputs):
# try to convert a given backend native tensor to TensorFlow tensor
# before passing it over to TFDataScope
is_tf_backend = config.backend() == "tensorflow"
is_in_tf_graph = not tf.executing_eagerly()
contains_ragged = lambda y: any(
tree.map_structure(
lambda x: isinstance(x, (tf.RaggedTensor, tf.SparseTensor)),
tree.flatten(y),
)
)
inputs_contain_ragged = contains_ragged(inputs)
if not inputs_contain_ragged:
if not is_tf_backend and not inputs_contain_ragged:
inputs = tree.map_structure(
lambda x: tf.convert_to_tensor(x), inputs
)
Expand All @@ -444,13 +447,15 @@ def call(self, inputs):
# backend native tensors. This is to avoid breaking TF data
# pipelines that can't easily be ported to become backend
# agnostic.
if not inputs_contain_ragged and not contains_ragged(outputs):
outputs = tree.map_structure(
# some layers return None, handle that case when
# converting to tensors
lambda x: ops.convert_to_tensor(x) if x is not None else x,
outputs,
)
# Skip this step for TF backend or if in `tf.graph` like `tf.data`.
if not is_tf_backend and not is_in_tf_graph:
if not inputs_contain_ragged and not contains_ragged(outputs):
outputs = tree.map_structure(
# some layers return None, handle that case when
# converting to tensors
lambda x: ops.convert_to_tensor(x) if x is not None else x,
outputs,
)
return outputs

def _augment(self, inputs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,11 @@ def in_tf_function(inputs):
self.assertNotAllClose(
segmentation_mask_diff[0], segmentation_mask_diff[1]
)

def test_augment_tf_data_pipeline(self):
image = np.random.random(size=(1, 8, 8, 3)).astype("float32")
tf_dataset = tf.data.Dataset.from_tensor_slices(image).map(
RandomAddLayer(fixed_value=2.0)
)
output = iter(tf_dataset).get_next()
self.assertAllClose(image[0] + 2.0, output)