23
23
24
24
from keras_cv import bounding_box
25
25
from keras_cv .api_export import keras_cv_export
26
+ from keras_cv .backend import config
26
27
from keras_cv .backend import keras
27
28
from keras_cv .backend import ops
28
29
from keras_cv .backend import scope
@@ -411,14 +412,16 @@ def get_random_transformation(
411
412
def call (self , inputs ):
412
413
# try to convert a given backend native tensor to TensorFlow tensor
413
414
# before passing it over to TFDataScope
415
+ is_tf_backend = config .backend () == "tensorflow"
416
+ is_in_tf_graph = not tf .executing_eagerly ()
414
417
contains_ragged = lambda y : any (
415
418
tree .map_structure (
416
419
lambda x : isinstance (x , (tf .RaggedTensor , tf .SparseTensor )),
417
420
tree .flatten (y ),
418
421
)
419
422
)
420
423
inputs_contain_ragged = contains_ragged (inputs )
421
- if not inputs_contain_ragged :
424
+ if not is_tf_backend and not inputs_contain_ragged :
422
425
inputs = tree .map_structure (
423
426
lambda x : tf .convert_to_tensor (x ), inputs
424
427
)
@@ -444,13 +447,15 @@ def call(self, inputs):
444
447
# backend native tensors. This is to avoid breaking TF data
445
448
# pipelines that can't easily be ported to become backend
446
449
# agnostic.
447
- if not inputs_contain_ragged and not contains_ragged (outputs ):
448
- outputs = tree .map_structure (
449
- # some layers return None, handle that case when
450
- # converting to tensors
451
- lambda x : ops .convert_to_tensor (x ) if x is not None else x ,
452
- outputs ,
453
- )
450
+ # Skip this step for TF backend or if in `tf.graph` like `tf.data`.
451
+ if not is_tf_backend and not is_in_tf_graph :
452
+ if not inputs_contain_ragged and not contains_ragged (outputs ):
453
+ outputs = tree .map_structure (
454
+ # some layers return None, handle that case when
455
+ # converting to tensors
456
+ lambda x : ops .convert_to_tensor (x ) if x is not None else x ,
457
+ outputs ,
458
+ )
454
459
return outputs
455
460
456
461
def _augment (self , inputs ):
0 commit comments