How can I convert a trained TensorFlow model to Keras?

I think a Keras callback is another way to do this.

In TensorFlow, the checkpoint (.ckpt) can be saved with:

saver = tf.train.Saver()
saver.save(sess, checkpoint_name)
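With the default (V2) checkpoint format, saver.save writes several files that share one prefix, for example XXXX.ckpt.index, XXXX.ckpt.data-00000-of-00001 and XXXX.ckpt.meta, and saver.restore expects that prefix rather than any one of those files. The helper below is a hypothetical sketch (checkpoint_prefix is my own name, not a TensorFlow API) that recovers the prefix if you were handed one of the concrete files:

```python
import re

# Suffixes that tf.train.Saver (V2 format) appends to the checkpoint prefix.
_CKPT_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Strip a known checkpoint-file suffix so the result can be passed to saver.restore()."""
    return _CKPT_SUFFIX.sub("", path)


print(checkpoint_prefix("./XXXX.ckpt.index"))                # → ./XXXX.ckpt
print(checkpoint_prefix("./XXXX.ckpt.data-00000-of-00001"))  # → ./XXXX.ckpt
print(checkpoint_prefix("./XXXX.ckpt"))                      # → ./XXXX.ckpt
```

If you only know the directory, tf.train.latest_checkpoint(directory) returns the prefix of the most recent checkpoint recorded there.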

To load the checkpoint in Keras, you can use a callback class like the following:

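import keras
import tensorflow as tf

class RestoreCkptCallback(keras.callbacks.Callback):
    def __init__(self, pretrained_file):
        super().__init__()
        self.pretrained_file = pretrained_file
        # Reuse the session Keras trains in (TF 1.x graph mode)
        self.sess = keras.backend.get_session()
        # Saver over the graph's variables: create this after the model is built
        self.saver = tf.train.Saver()

    def on_train_begin(self, logs=None):
        # Restore once, before the first epoch
        if self.pretrained_file:
            self.saver.restore(self.sess, self.pretrained_file)
            print('load weights: OK.')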
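tf.train.Saver restores variables by name, so the restore only succeeds if the Keras model's variables have the same names as the ones stored in the checkpoint. Before training you could compare the two name lists. The sketch below works on plain Python lists; in practice (an assumption about your setup) the checkpoint names could come from tf.train.list_variables(path), which returns (name, shape) pairs, and the model names from [v.name for v in model.weights]. The function compare_variable_names and the example names are hypothetical.

```python
import re


def _strip_output_index(name):
    # Variable.name looks like "dense/kernel:0"; checkpoint keys have no ":0".
    return re.sub(r":\d+$", "", name)


def compare_variable_names(ckpt_names, model_names):
    """Return (missing, unused): model variables absent from the checkpoint,
    and checkpoint entries that no model variable will consume."""
    ckpt = set(ckpt_names)
    model = {_strip_output_index(n) for n in model_names}
    missing = sorted(model - ckpt)
    unused = sorted(ckpt - model)
    return missing, unused


# Hypothetical names, for illustration only.
ckpt_names = ["dense/kernel", "dense/bias", "global_step"]
model_names = ["dense/kernel:0", "dense/bias:0", "dense_1/kernel:0"]

missing, unused = compare_variable_names(ckpt_names, model_names)
print("missing from checkpoint:", missing)  # → missing from checkpoint: ['dense_1/kernel']
print("unused in checkpoint:", unused)      # → unused in checkpoint: ['global_step']
```

A non-empty missing list means saver.restore would typically fail with a "not found in checkpoint" error for those variables. If the names simply differ, tf.train.Saver also accepts var_list as a dict mapping checkpoint names to variables, which lets you restore under a renaming.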

Then, in your Keras script:

model.compile(loss="categorical_crossentropy", optimizer="rmsprop")
# Create the callback after the model is built, so its Saver sees the model's variables
restore_ckpt_callback = RestoreCkptCallback(pretrained_file="./XXXX.ckpt")
model.fit(x_train, y_train, batch_size=128, epochs=20, callbacks=[restore_ckpt_callback])
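Doing the restore in on_train_begin matters: Keras calls that hook once at the start of fit, before the first epoch, whereas a hook such as on_epoch_begin would overwrite the freshly trained weights at every epoch. The toy loop below is not Keras code; ToyCallback, toy_fit and RecordingRestore are hypothetical stand-ins that only mimic the order in which fit invokes these hooks, so you can see the restore run exactly once.

```python
class ToyCallback:
    """Minimal stand-in for keras.callbacks.Callback (hooks do nothing by default)."""

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


def toy_fit(epochs, callbacks):
    """Mimic the hook order of model.fit: train_begin once, epoch hooks per epoch."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        # ... one epoch of weight updates would happen here ...
        for cb in callbacks:
            cb.on_epoch_end(epoch)
    for cb in callbacks:
        cb.on_train_end()


class RecordingRestore(ToyCallback):
    """Counts restores instead of calling saver.restore()."""

    def __init__(self):
        self.restores = 0

    def on_train_begin(self, logs=None):
        self.restores += 1


cb = RecordingRestore()
toy_fit(epochs=20, callbacks=[cb])
print(cb.restores)  # → 1
```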

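This restores the TensorFlow weights once, just before training starts. It only works if the variable names stored in the checkpoint match the names of the Keras model's variables, because tf.train.Saver restores by name.
It is easy to implement, and I hope it helps.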
