Exception in thread "main" org.tensorflow.exceptions.TensorFlowException: No gradient defined for op: Concat #156

Open
aday00 opened this issue Nov 23, 2020 · 0 comments

aday00 commented Nov 23, 2020

Concat has no gradient defined, so a deep net that contains a Concat op cannot be trained. This has come up on the mailing list, e.g. https://groups.google.com/a/tensorflow.org/g/jvm/c/TTuT3yzoKWs/m/pTQX1w_XAgAJ

Exception in thread "main" org.tensorflow.exceptions.TensorFlowException: No gradient defined for op: Concat. Please see https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md for instructions on how to add C++ gradients.
        at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:101)
        at org.tensorflow.Graph.addGradients(Graph.java:649)
        at org.tensorflow.Graph.addGradients(Graph.java:267)
        at org.tensorflow.Graph.addGradients(Graph.java:301)
        at org.tensorflow.framework.optimizers.Optimizer.computeGradients(Optimizer.java:113)
        at org.tensorflow.framework.optimizers.Optimizer.minimize(Optimizer.java:94)
        at org.tensorflow.framework.optimizers.Optimizer.minimize(Optimizer.java:90)
        ...
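For reference, the missing gradient is mathematically simple. Concat only places its inputs side by side, so the gradient with respect to each input is the matching slice of the upstream gradient. The sketch below shows this on plain 2-D Scala arrays for a concat along axis 1 (the feature axis). `ConcatGradSketch`, `concat` and `concatGrad` are hypothetical names for illustration, not TensorFlow Java APIs; an actual fix would still need a C++ gradient registered for the op, as the error message says.

```scala
object ConcatGradSketch {
  type Matrix = Array[Array[Float]]

  // Forward pass: concatenate two [batch, n] matrices along axis 1, row by row.
  def concat(a: Matrix, b: Matrix): Matrix = {
    require(a.length == b.length, "batch sizes must match")
    a.zip(b).map { case (rowA, rowB) => rowA ++ rowB }
  }

  // Backward pass: given dL/dy for y = concat(a, b), return (dL/da, dL/db).
  // Each input's gradient is just its own column range of the upstream gradient.
  def concatGrad(dy: Matrix, widthA: Int): (Matrix, Matrix) = {
    val dA = dy.map(row => row.slice(0, widthA))
    val dB = dy.map(row => row.slice(widthA, row.length))
    (dA, dB)
  }

  def main(args: Array[String]): Unit = {
    val dy: Matrix = Array(Array(1f, 2f, 3f, 4f, 5f))
    val (dA, dB) = concatGrad(dy, widthA = 2)
    println(dA.map(_.mkString(",")).mkString(";")) // 1.0,2.0
    println(dB.map(_.mkString(",")).mkString(";")) // 3.0,4.0,5.0
  }
}
```

Because the backward pass is only slicing, it has no numerical subtleties; the work is in wiring it into the gradient registry.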

However, because zero padding and add both have gradients defined, my current workaround is a "fake concat": zero-pad each of the two vectors along the feature axis, on opposite sides, so that both reach the combined width, then add them together. In Scala, this is:

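// Pad the 512-wide input with 1024 zeros on the left of axis 1: [zeros(1024) | input512]
val padded1 = tf.withName("padded1").pad(some_input_of_512_dimensions,
                                         tf.constant(Array(Array(0,0), Array(1024,0))), tf.constant(0.0f))
// Pad the 1024-wide input with 512 zeros on the right of axis 1: [input1024 | zeros(512)]
val padded2 = tf.withName("padded2").pad(some_other_input_of_1024_dimensions,
                                         tf.constant(Array(Array(0,0), Array(0,512))), tf.constant(0.0f))
// Both tensors are now 1536 wide with non-overlapping non-zero columns, so their sum equals
// concat(some_other_input_of_1024_dimensions, some_input_of_512_dimensions) along axis 1.
// tf.concat(...) itself has no gradient implemented, so it doesn't work during training.
val fake_concat = tf.withName("fake_concat").math.add(padded1, padded2)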
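Why the trick works: after padding, both tensors have the combined width and their non-zero columns occupy disjoint ranges, so the element-wise sum equals the concatenation. Note the order: with the paddings above, the 1024-wide input ends up first. On the backward side, Add passes the upstream gradient unchanged to both operands, and Pad's gradient drops the padded columns, which leaves exactly the slices Concat's own gradient would produce. The plain-array sketch below checks this with small widths (3 and 2 standing in for 1024 and 512). `FakeConcatSketch` and its helpers are hypothetical names; this models the math only, not TensorFlow Java itself.

```scala
object FakeConcatSketch {
  type Matrix = Array[Array[Float]]

  // Zero-pad each row along axis 1: `before` zeros on the left, `after` zeros on the right.
  def padCols(x: Matrix, before: Int, after: Int): Matrix =
    x.map(row => Array.fill(before)(0f) ++ row ++ Array.fill(after)(0f))

  // Element-wise addition of two matrices with the same shape.
  def add(a: Matrix, b: Matrix): Matrix =
    a.zip(b).map { case (ra, rb) => ra.zip(rb).map { case (u, v) => u + v } }

  // Reference concat along axis 1.
  def concat(a: Matrix, b: Matrix): Matrix =
    a.zip(b).map { case (ra, rb) => ra ++ rb }

  // "Fake concat": pad `first` on the right and `second` on the left, then add.
  // Produces the same values as concat(first, second).
  def fakeConcat(first: Matrix, second: Matrix): Matrix = {
    val w1 = first.head.length
    val w2 = second.head.length
    val paddedFirst = padCols(first, 0, w2)   // [first | zeros(w2)]
    val paddedSecond = padCols(second, w1, 0) // [zeros(w1) | second]
    add(paddedFirst, paddedSecond)
  }

  // Gradient of zero padding w.r.t. its input: keep only the unpadded columns of dy.
  def padGrad(dy: Matrix, before: Int, width: Int): Matrix =
    dy.map(row => row.slice(before, before + width))

  def main(args: Array[String]): Unit = {
    val x1: Matrix = Array(Array(1f, 2f))        // stands in for the 512-wide input
    val x2: Matrix = Array(Array(10f, 20f, 30f)) // stands in for the 1024-wide input
    // Same padding pattern as the snippet above: x2 comes first, x1 second.
    val fake = fakeConcat(x2, x1)
    println(fake.head.mkString(","))                     // 10.0,20.0,30.0,1.0,2.0
    println(fake.head.sameElements(concat(x2, x1).head)) // true

    // Backward: Add forwards dy to both padded tensors, then each Pad gradient slices it.
    val dy: Matrix = Array(Array(0.1f, 0.2f, 0.3f, 0.4f, 0.5f))
    val dx2 = padGrad(dy, before = 0, width = 3)
    val dx1 = padGrad(dy, before = 3, width = 2)
    println(dx2.head.mkString(",") + " | " + dx1.head.mkString(","))
  }
}
```

The cost of the workaround is extra memory and an extra Add over the full combined width, which is why a native Concat gradient is still preferable.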

Concat is useful for BatchNorm, etc., so I mentioned this in #135 (comment); the fake concat pseudocode is also at https://groups.google.com/a/tensorflow.org/g/jvm/c/TTuT3yzoKWs/m/pTQX1w_XAgAJ

Models that use the fake concat train and save fine, but loading them can be problematic (that would be a separate ticket). Hopefully this ticket is useful for tracking the addition of gradients for Concat.
