Get the value of some weights in a model trained by TensorFlow

In TensorFlow, trained weights are represented by tf.Variable objects. If you created a tf.Variable yourself (say, one called v), you can get its value as a NumPy array by calling sess.run(v), where sess is a tf.Session.
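As a minimal sketch of the sess.run(v) call described above: the variable name, shape, and zero initializer here are illustrative assumptions standing in for a real trained weight, and the tf.compat.v1 alias is used so the same graph-mode code should also run on TensorFlow 2.x installations.

```python
import numpy as np
import tensorflow as tf

# Assumption: TF 1.x-style graph API, reached through tf.compat.v1.
tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    # Hypothetical variable standing in for a trained weight.
    v = tf1.get_variable("v", shape=[2, 3], initializer=tf1.zeros_initializer())
    init_op = tf1.global_variables_initializer()

    with tf1.Session() as sess:
        # Here the variable is only initialized; in a trained model its
        # value would come from training or from a restored checkpoint.
        sess.run(init_op)
        value = sess.run(v)  # NumPy array with the current value of v

print(type(value).__name__, value.shape)
```

The result is an ordinary numpy.ndarray, so it can be inspected, saved, or plotted without further TensorFlow calls.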

If you do not have a reference to the tf.Variable, call tf.trainable_variables(), which returns a list of all trainable tf.Variable objects in the current graph. You can then select the one you want by matching its v.name property. For example:

import tensorflow as tf

# Desired variable is called "tower_2/filter:0".
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0]
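The name lookup can be combined with sess.run to read the value in one place. The sketch below is one possible arrangement, not a canonical recipe: find_trainable_variable is a hypothetical helper, the "tower_2" scope and 3x3 shape are made up to reproduce the name used above, and tf.compat.v1 is assumed so the code should also run under TensorFlow 2.x.

```python
import tensorflow as tf

# Assumption: TF 1.x-style graph API, reached through tf.compat.v1.
tf1 = tf.compat.v1


def find_trainable_variable(name):
    """Hypothetical helper: return the trainable variable in the current
    default graph whose name equals `name`, or raise LookupError."""
    matches = [v for v in tf1.trainable_variables() if v.name == name]
    if not matches:
        known = [v.name for v in tf1.trainable_variables()]
        raise LookupError("no trainable variable %r; known: %s" % (name, known))
    return matches[0]


graph = tf.Graph()
with graph.as_default():
    # Illustrative variable that ends up named "tower_2/filter:0".
    with tf1.variable_scope("tower_2"):
        tf1.get_variable("filter", shape=[3, 3],
                         initializer=tf1.ones_initializer())

    var = find_trainable_variable("tower_2/filter:0")

    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        value = sess.run(var)  # NumPy array with the variable's value

print(var.name, value.shape)
```

Raising an error that lists the known names is a design choice for debugging: a plain [0] index on an empty list fails with an IndexError that does not say which names were available.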
