TensorFlow: slow performance when getting gradients at inputs

The tf.gradients() function builds a new backpropagation graph each time it is called. Because your code calls it inside the loop, TensorFlow has to parse a new graph on every iteration, and that is the cause of the slowdown. (This can be surprisingly expensive: the current version of TensorFlow is optimized for executing the same graph a large number of times.)
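To see the graph growing, you can count its operations after each tf.gradients() call. This is only a minimal sketch: the placeholder and cost below are hypothetical stand-ins for your model, and it assumes a TensorFlow build that exposes the 1.x graph API as tf.compat.v1.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API (an assumption about the installed version)

graph = tf1.Graph()
with graph.as_default():
    # Hypothetical stand-ins for the model's input and cost.
    x = tf1.placeholder(tf.float32, shape=[None, 3])
    cost = tf1.reduce_sum(tf1.square(x))

    op_counts = []
    for _ in range(3):
        # Each call adds a fresh set of backpropagation ops to the graph.
        tf1.gradients(cost, x)
        op_counts.append(len(graph.get_operations()))

# The operation count rises with every call.
print(op_counts)
```

In a real training loop the same growth happens once per batch, so the graph keeps getting larger for as long as training runs.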

Fortunately the solution is easy: just compute the gradients once, outside the loop. You can restructure your code as follows:

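# Build every graph node once, before the training loop.
# (TensorFlow 1.0 and later require labels/logits to be passed by keyword.)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.network, labels=self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
# Gradient of the cost with respect to the input, also built only once.
grads_wrt_input_tensor = tf.gradients(cost, self.x)[0]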
# ...
for i in range(epochs):
    # ...
    for batch in batches:
        # ...
        _, grads_wrt_input = sess.run([optimizer, grads_wrt_input_tensor],
                                      feed_dict=feed_dict)

Note that, for performance, I also combined the two sess.run() calls into one. Fetching the training op and the input gradients in the same call ensures that the forward propagation, and much of the backpropagation, will be reused.


As an aside, one tip to find performance bugs like this is to call tf.get_default_graph().finalize() before starting your training loop. This will raise an exception if you inadvertently add any nodes to the graph, which makes it easier to trace the cause of these bugs.
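Here is a small sketch of what that looks like. It uses an explicit tf.compat.v1.Graph rather than the default graph so that it also runs under TensorFlow 2.x (an assumption about your setup), and the placeholder and cost are again hypothetical stand-ins for your model.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API (an assumption about the installed version)

graph = tf1.Graph()
with graph.as_default():
    # Hypothetical stand-ins for the model's input and cost.
    x = tf1.placeholder(tf.float32, shape=[None, 3])
    cost = tf1.reduce_sum(tf1.square(x))
    grads_wrt_input_tensor = tf1.gradients(cost, x)[0]  # built once, up front

    graph.finalize()  # from here on the graph is read-only

    error_message = None
    try:
        # Simulates accidentally rebuilding the gradients inside the loop.
        tf1.gradients(cost, x)
    except RuntimeError as err:
        error_message = str(err)

print(error_message)
```

The traceback of that exception points at the line that tried to add nodes, which is usually enough to find the offending call.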
