Can’t save custom subclassed model

TensorFlow 2.2

Thanks to @cal for pointing out that newer TensorFlow versions support saving custom models!

Use model.save to save the whole model and keras.models.load_model to restore a previously saved subclassed model. The following code snippet shows how.

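import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class ThreeLayerMLP(keras.Model):

  def __init__(self, name=None):
    super(ThreeLayerMLP, self).__init__(name=name)
    self.dense_1 = layers.Dense(64, activation='relu', name="dense_1")
    self.dense_2 = layers.Dense(64, activation='relu', name="dense_2")
    self.pred_layer = layers.Dense(10, name="predictions")

  def call(self, inputs):
    x = self.dense_1(inputs)
    x = self.dense_2(x)
    return self.pred_layer(x)

def get_model():
  return ThreeLayerMLP(name="3_layer_mlp")

model = get_model()
# A subclassed model only learns its input shape once it has been run on data
# (e.g. via fit or predict), and saving needs that shape; 784 features is just an example
model.predict(tf.zeros((1, 784)))
# Save the whole model in the TensorFlow SavedModel format
model.save('path_to_my_model', save_format="tf")

# Recreate the exact same model purely from the file
new_model = keras.models.load_model('path_to_my_model')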

See: Save and serialize models with Keras – Part II: Saving and Loading of Subclassed Models

TensorFlow 2.0

TL;DR:

  1. Do not use model.save() for a custom subclassed Keras model;
  2. Use save_weights() and load_weights() instead.

With help from the TensorFlow team, it turns out that the best practice for saving a custom subclassed Keras model is to save its weights and load them back when needed.

We cannot simply save a custom subclassed Keras model because it contains custom code, which cannot be serialized safely. However, the weights can be saved and loaded without any problem, as long as we have the same model structure and custom code.
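A minimal sketch of the weights-only workflow, assuming the same ThreeLayerMLP class as in the TF 2.2 snippet (redefined here so the example is self-contained). The file name my_model.weights.h5 is hypothetical. The `.weights.h5` suffix makes Keras write HDF5; newer Keras 3 releases require that suffix, while TF 2.x also accepts an extension-less path and then writes the TensorFlow checkpoint format. The key step is re-creating the model from its class and building it (calling it once) before load_weights.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class ThreeLayerMLP(keras.Model):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.dense_1 = layers.Dense(64, activation="relu", name="dense_1")
        self.dense_2 = layers.Dense(64, activation="relu", name="dense_2")
        self.pred_layer = layers.Dense(10, name="predictions")

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        return self.pred_layer(x)


# Example input: a batch of 2 vectors with 784 features (an arbitrary choice).
x = np.random.RandomState(0).rand(2, 784).astype("float32")

# Build the original model by calling it once, then save only its weights.
model = ThreeLayerMLP(name="3_layer_mlp")
original_out = np.asarray(model(x))
weights_path = "my_model.weights.h5"  # hypothetical file name
model.save_weights(weights_path)

# Later (or in another process): re-create the model from the same class,
# build it with an input of the same shape, then restore the weights.
restored = ThreeLayerMLP(name="3_layer_mlp")
restored(x)
restored.load_weights(weights_path)
restored_out = np.asarray(restored(x))

print(np.allclose(original_out, restored_out))
```

Because only numeric arrays are stored, this needs the ThreeLayerMLP source code at load time, which is exactly the trade-off described above.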

Francois Chollet, the author of Keras, wrote a great Colab tutorial on how to save and load Sequential, Functional, and custom subclassed Keras models in TensorFlow 2.0. Its "Saving Subclassed Models" section says:

Sequential models and Functional models are datastructures that represent a DAG of layers. As such, they can be safely serialized and deserialized.

A subclassed model differs in that it’s not a datastructure, it’s a
piece of code. The architecture of the model is defined via the body
of the call method. This means that the architecture of the model
cannot be safely serialized. To load a model, you’ll need to have
access to the code that created it (the code of the model subclass).
Alternatively, you could be serializing this code as bytecode (e.g.
via pickling), but that’s unsafe and generally not portable.
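To make the quoted point concrete, here is a small illustrative sketch with two hypothetical classes, PlainMLP and ResidualMLP (not from the tutorial). Both own identical layers and therefore identical weight lists, but ResidualMLP's call adds a skip connection. Copying the weights from one to the other succeeds, yet the outputs differ, because that part of the architecture exists only in the Python body of call and is not contained in the weights.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class PlainMLP(keras.Model):
    """Two Dense layers applied in sequence."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = layers.Dense(8, activation="relu")
        self.dense_2 = layers.Dense(8)

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs))


class ResidualMLP(PlainMLP):
    """Same layers (hence same weights), but call() adds a skip connection."""

    def call(self, inputs):
        return self.dense_2(self.dense_1(inputs)) + inputs


x = np.random.RandomState(0).rand(4, 8).astype("float32")

plain = PlainMLP()
residual = ResidualMLP()
plain(x)  # build both models so their variables exist
residual(x)

# The weight lists are structurally identical, so copying them works.
residual.set_weights(plain.get_weights())
weight_shapes = [w.shape for w in plain.get_weights()]
same_weights = all(
    np.array_equal(a, b)
    for a, b in zip(plain.get_weights(), residual.get_weights())
)

# Identical weights, different code in call(): different outputs.
outputs_differ = not np.allclose(np.asarray(plain(x)), np.asarray(residual(x)))
print(same_weights, outputs_differ)
```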
