diff --git a/src/model.py b/src/model.py index 230b83cc2..e4eb138c0 100644 --- a/src/model.py +++ b/src/model.py @@ -17,7 +17,7 @@ def shape_list(x): dynamic = tf.shape(x) return [dynamic[i] if s is None else s for i, s in enumerate(static)] -def softmax(x, axis=-1): +def softmax(x, axis=-1):"""Normalises the probability function to prevent the overflow of all elements""" x = x - tf.reduce_max(x, axis=axis, keepdims=True) ex = tf.exp(x) return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)