Save and load model optimizer state

You can extract the important lines from the load_model and save_model functions.

For saving optimizer states, in save_model:

# Save optimizer weights.
symbolic_weights = getattr(model.optimizer, 'weights')
if symbolic_weights:
    optimizer_weights_group = f.create_group('optimizer_weights')
    weight_values = K.batch_get_value(symbolic_weights)

For loading optimizer states, in load_model:

# Set optimizer weights.
if 'optimizer_weights' in f:
    # Build train function (to get weight updates).
    if isinstance(model, Sequential):
        model.model._make_train_function()
    else:
        model._make_train_function()

    # ...

    try:
        model.optimizer.set_weights(optimizer_weight_values)

Combining the lines above, here’s an example:

  1. First fit the model for 5 epochs.
X, y = np.random.rand(100, 50), np.random.randint(2, size=100)
x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(X, y, epochs=5)

Epoch 1/5
100/100 [==============================] - 0s 4ms/step - loss: 0.7716
Epoch 2/5
100/100 [==============================] - 0s 64us/step - loss: 0.7678
Epoch 3/5
100/100 [==============================] - 0s 82us/step - loss: 0.7665
Epoch 4/5
100/100 [==============================] - 0s 56us/step - loss: 0.7647
Epoch 5/5
100/100 [==============================] - 0s 76us/step - loss: 0.7638
  1. Now save the weights and optimizer states.
model.save_weights('weights.h5')
symbolic_weights = getattr(model.optimizer, 'weights')
weight_values = K.batch_get_value(symbolic_weights)
with open('optimizer.pkl', 'wb') as f:
    pickle.dump(weight_values, f)
  1. Rebuild the model in another python session, and load weights.
x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer="adam", loss="binary_crossentropy")

model.load_weights('weights.h5')
model._make_train_function()
with open('optimizer.pkl', 'rb') as f:
    weight_values = pickle.load(f)
model.optimizer.set_weights(weight_values)
  1. Continue model training.
model.fit(X, y, epochs=5)

Epoch 1/5
100/100 [==============================] - 0s 674us/step - loss: 0.7629
Epoch 2/5
100/100 [==============================] - 0s 49us/step - loss: 0.7617
Epoch 3/5
100/100 [==============================] - 0s 49us/step - loss: 0.7611
Epoch 4/5
100/100 [==============================] - 0s 55us/step - loss: 0.7601
Epoch 5/5
100/100 [==============================] - 0s 49us/step - loss: 0.7594

Leave a Comment