TensorFlow Save & Restore Model


The Keras API provides built-in callbacks that save a model regularly during training. To save a model and restore it later, we create a ModelCheckpoint callback and pass it to model.fit. The model is then saved periodically as training runs.
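Here is a minimal sketch of this pattern, assuming TensorFlow 2.x. The toy data, `create_model`, and the `training_1` directory are hypothetical stand-ins. The `.weights.h5` suffix is used because Keras 3 (TF 2.16+) requires it when `save_weights_only=True`, and older tf.keras also accepts it as an HDF5 file.

```python
import os

import numpy as np
import tensorflow as tf

# Hypothetical toy data standing in for a real training set.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 4)).astype("float32")
y_train = (x_train.sum(axis=1, keepdims=True) > 0).astype("float32")


def create_model():
    """Illustrative small classifier; the layer sizes are arbitrary."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")
    return model


os.makedirs("training_1", exist_ok=True)
checkpoint_path = "training_1/cp.weights.h5"

# Writes the weights to the same file at the end of every epoch.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
)

model = create_model()
model.fit(x_train, y_train, epochs=3, batch_size=16,
          callbacks=[cp_callback], verbose=0)
```

Because the file name never changes, each epoch overwrites the previous checkpoint.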


With the configuration above, the model is saved every epoch. If save_best_only is set to True, the model is saved only when the monitored quantity (validation loss by default) is the lowest seen so far.
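The sketch below shows the `save_best_only` variant. It assumes the same kind of hypothetical toy data as before and holds out a validation split so that `val_loss` exists. The `training_best` directory name is also illustrative.

```python
import os

import numpy as np
import tensorflow as tf

# Hypothetical toy data; validation_split below holds out the last 25%.
rng = np.random.default_rng(0)
x = rng.normal(size=(80, 4)).astype("float32")
y = (x.sum(axis=1, keepdims=True) > 0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

os.makedirs("training_best", exist_ok=True)
best_path = "training_best/best.weights.h5"

best_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=best_path,
    monitor="val_loss",     # the default, spelled out for clarity
    save_best_only=True,    # write only when val_loss improves
    save_weights_only=True,
)

history = model.fit(x, y, validation_split=0.25, epochs=5, batch_size=16,
                    callbacks=[best_callback], verbose=0)
best_epoch = int(np.argmin(history.history["val_loss"])) + 1
print("File holds the weights from epoch", best_epoch)
```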


With a fixed file path, each new checkpoint overwrites the old one because it is written under the same name. To keep every checkpoint, include the epoch number in the file name, for example training_2/cp-{epoch:04d}.ckpt. ModelCheckpoint fills in {epoch} when it saves, so earlier checkpoints are not overwritten.


This configuration saves the model every 5 epochs. Each save is written as a separate, epoch-numbered file under the training_2 directory.
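Below is a sketch of epoch-numbered checkpoints saved every 5 epochs, assuming TensorFlow 2.x with hypothetical toy data. The `.weights.h5` suffix replaces `.ckpt` for Keras 3 compatibility. Because `save_freq` counts batches, "every 5 epochs" is expressed as 5 times the number of batches per epoch.

```python
import glob
import os
import shutil

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 4)).astype("float32")
y_train = (x_train.sum(axis=1, keepdims=True) > 0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

BATCH_SIZE = 16
STEPS_PER_EPOCH = len(x_train) // BATCH_SIZE  # 64 / 16 = 4 batches per epoch
SAVE_EVERY_N_EPOCHS = 5

shutil.rmtree("training_2", ignore_errors=True)  # start from an empty directory
os.makedirs("training_2")

# ModelCheckpoint fills in {epoch:04d}, so every save gets its own file.
checkpoint_path = "training_2/cp-{epoch:04d}.weights.h5"
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    # save_freq is measured in batches, so this means "every 5 epochs".
    save_freq=SAVE_EVERY_N_EPOCHS * STEPS_PER_EPOCH,
)

model.fit(x_train, y_train, epochs=10, batch_size=BATCH_SIZE,
          callbacks=[cp_callback], verbose=0)
print(sorted(os.listdir("training_2")))
```

With 10 epochs, this should leave two files in training_2, one per 5-epoch interval.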


Save the whole model vs. weights only

There are two options for saving a model: weights only, or the full model, which also includes the model architecture and the training state. If the save_weights_only flag is True when creating ModelCheckpoint, the callback saves the model with model.save_weights(filepath), which stores only the weights. If it is False, the full model is saved with model.save(filepath). In TF 2.x tf.keras this writes the SavedModel format by default.

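By default, ModelCheckpoint saves the model at the end of every epoch. This can be overridden with its save_freq argument, which accepts either "epoch" or an integer number of batches. Because the integer counts batches, saving every NUM_OF_EPOCHS epochs means multiplying by the number of steps (batches) per epoch:

save_freq = int(NUM_OF_EPOCHS * STEPS_PER_EPOCH)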

model.save_weights

If the model was saved with weights only, we must instantiate a new model before restoring the weights. Typically, we call the original Python code (create_model in our example) to create a model instance. We then load the weights with model.load_weights. The latest checkpoint in a directory can be located with tf.train.latest_checkpoint.
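Here is a sketch of restoring weights into a fresh instance, assuming TensorFlow 2.x with hypothetical toy data and a hypothetical `create_model`. One caveat: `tf.train.latest_checkpoint` reads the `checkpoint` index file that TF-format checkpoints write. Keras `.weights.h5` files (required by Keras 3) do not have one, so this sketch picks the newest file by its zero-padded epoch number instead.

```python
import glob
import os
import shutil

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 4)).astype("float32")
y_train = (x_train.sum(axis=1, keepdims=True) > 0).astype("float32")


def create_model():
    """Hypothetical architecture; it must match the one used for saving."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")
    return model


shutil.rmtree("training_2", ignore_errors=True)
os.makedirs("training_2")
trained = create_model()
trained.fit(x_train, y_train, epochs=3, batch_size=16, verbose=0,
            callbacks=[tf.keras.callbacks.ModelCheckpoint(
                "training_2/cp-{epoch:04d}.weights.h5",
                save_weights_only=True)])

# Zero-padded epoch numbers make lexical order match training order.
latest = sorted(glob.glob("training_2/cp-*.weights.h5"))[-1]

restored = create_model()      # same architecture, fresh random weights
restored.load_weights(latest)  # overwrite them with the saved values
print("Restored", latest)
```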


Without the ModelCheckpoint callback, we can call model.save_weights to save the model weights manually.
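The next sketch saves weights manually without a callback. It assumes TensorFlow 2.x. The path `checkpoints/my_checkpoint.weights.h5`, the toy data, and `create_model` are illustrative.

```python
import os

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 4)).astype("float32")
y_train = (x_train.sum(axis=1, keepdims=True) > 0).astype("float32")


def create_model():
    """Hypothetical architecture used for both saving and restoring."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")
    return model


model = create_model()
model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0)

os.makedirs("checkpoints", exist_ok=True)
weights_path = "checkpoints/my_checkpoint.weights.h5"
model.save_weights(weights_path)  # explicit save, no ModelCheckpoint involved

new_model = create_model()        # instantiate first, then load the weights
new_model.load_weights(weights_path)
```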


model.save

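To save the complete model, we call model.save(filepath). In TF 2.x tf.keras, this writes a SavedModel when filepath has no file extension, while Keras 3 expects a .keras file name instead. The saved model includes the optimizer state, so training can be resumed from the saved point. The dataset iterator position is not part of model.save, but the tf.train.Checkpoint approach described later can capture it. Because the model architecture and configuration are also saved, the model can be restored directly with tf.keras.models.load_model, without first creating a model instance in Python.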
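The following sketch saves and reloads a full model. It uses the single-file `.keras` format, assuming a recent TensorFlow/Keras (roughly TF 2.13+, and required by Keras 3). On older tf.keras, passing a directory path without an extension writes the SavedModel directory described above instead. The data and architecture are hypothetical.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 4)).astype("float32")
y_train = (x_train.sum(axis=1, keepdims=True) > 0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0)

# Architecture, weights, compile settings, and optimizer state go into one file.
model.save("my_model.keras")

# No model-building code needed: everything is read back from the file.
restored = tf.keras.models.load_model("my_model.keras")
preds_match = np.allclose(model.predict(x_train, verbose=0),
                          restored.predict(x_train, verbose=0), atol=1e-6)

# The restored model is already compiled, so training can simply continue.
restored.fit(x_train, y_train, epochs=1, batch_size=16, verbose=0)
```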


When a model is saved as a SavedModel, all of its tf.Variable objects are saved, and every method decorated with @tf.function is saved as a graph. In our example, the model is saved as dnn_model.
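To show what gets saved at the lowest level, the sketch below uses a small `tf.Module` rather than a Keras model. The `Scaler` class is hypothetical. Its variables and its `@tf.function` method with a fixed input signature are written out by `tf.saved_model.save`.

```python
import os

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module computing y = w * x + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)  # tf.Variables are saved with their values
        self.b = tf.Variable(1.0)

    # The input_signature lets TF trace this into a graph at save time.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x + self.b


module = Scaler()
tf.saved_model.save(module, "dnn_model")
print(sorted(os.listdir("dnn_model")))
```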


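We no longer need the original Python code, because TF executes the saved graph directly. This reduces the chance of mistakes during production deployment. A SavedModel directory such as my_model contains a saved_model.pb file holding the graphs, a variables subdirectory holding the variable values, and an assets subdirectory for any auxiliary files.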
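Here is a sketch of loading a SavedModel without using its Python class, assuming TensorFlow 2.x. `Scaler` is the same hypothetical module as before and is used only to produce the files. `tf.saved_model.load` rebuilds a generic object from the stored graph and variables rather than a `Scaler` instance.

```python
import numpy as np
import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module used only to create the SavedModel."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)
        self.b = tf.Variable(1.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x + self.b


tf.saved_model.save(Scaler(), "my_model")

# Loading does not call Scaler's code: the traced graph is executed instead.
loaded = tf.saved_model.load("my_model")
result = loaded(tf.constant([1.0, 2.0]))
print(type(loaded).__name__, result.numpy())
```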


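The catch is that only traced graphs are saved. Any method of a custom layer or model that must be callable after loading needs to be decorated with @tf.function, because plain Python methods are not stored in the SavedModel.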
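The sketch below illustrates that limitation under TensorFlow 2.x. The `Custom` module and its method names are hypothetical. After loading, the `@tf.function` method is available, but the undecorated method is gone.

```python
import numpy as np
import tensorflow as tf


class Custom(tf.Module):
    """Hypothetical module with one traced and one plain method."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def scaled(self, x):           # traced into a graph, so it is saved
        return self.scale * x

    def scaled_plain(self, x):     # ordinary Python method, so it is NOT saved
        return self.scale * x


tf.saved_model.save(Custom(), "custom_model")
loaded = tf.saved_model.load("custom_model")

print(hasattr(loaded, "scaled"))        # the traced method survives
print(hasattr(loaded, "scaled_plain"))  # the plain method does not
```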


CheckpointManager

If we write our own training loop with the lower-level TensorFlow API instead of model.fit, we can use tf.train.CheckpointManager to save models. The boilerplate for this setup creates a toy dataset and a model, and it defines the training step.
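Here is a sketch of that boilerplate, modeled on the TensorFlow Guide's checkpoint example and assuming TensorFlow 2.x. `Net`, `toy_dataset`, and `train_step` are the illustrative names used throughout this section.

```python
import math

import tensorflow as tf


class Net(tf.keras.Model):
    """A simple linear model: one Dense layer with 5 outputs."""

    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    """Infinite stream of batches where y = 5 * x + [0, 1, 2, 3, 4]."""
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return tf.data.Dataset.from_tensor_slices(
        dict(x=inputs, y=labels)).repeat().batch(2)


def train_step(net, example, optimizer):
    """Runs one optimization step on `example` and returns the L1 loss."""
    with tf.GradientTape() as tape:
        output = net(example["x"])
        loss = tf.reduce_mean(tf.abs(output - example["y"]))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss


net = Net()
opt = tf.keras.optimizers.Adam(0.1)
iterator = iter(toy_dataset())
first_loss = train_step(net, next(iterator), opt)
print("loss after one step:", float(first_loss))
```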


To save checkpoints, we create a tf.train.CheckpointManager that manages a tf.train.Checkpoint. The Checkpoint tracks the model, the optimizer, the training state (step), and the dataset iterator. Before training starts, we restore the Checkpoint from the latest saved checkpoint, if one exists. This loads the model weights and restores the state of the optimizer, the dataset iterator, and the training step. In short, we resume the whole training state from when the model was last saved, not just the model weights.
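The sketch below shows the save/restore loop, following the TensorFlow Guide's pattern and assuming TensorFlow 2.x. The helpers repeat the boilerplate so the block runs on its own, and `./tf_ckpts` is an illustrative directory. Calling `train_and_checkpoint` twice simulates a restart: the second call builds fresh objects and resumes from the last checkpoint.

```python
import shutil

import tensorflow as tf


class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return tf.data.Dataset.from_tensor_slices(
        dict(x=inputs, y=labels)).repeat().batch(2)


def train_step(net, example, optimizer):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.abs(net(example["x"]) - example["y"]))
    grads = tape.gradient(loss, net.trainable_variables)
    optimizer.apply_gradients(zip(grads, net.trainable_variables))
    return loss


CKPT_DIR = "./tf_ckpts"
shutil.rmtree(CKPT_DIR, ignore_errors=True)  # start from a clean directory


def train_and_checkpoint(num_steps=50):
    """Builds fresh objects, resumes from the latest checkpoint if any,
    trains for num_steps, and saves every 10 steps."""
    net = Net()
    opt = tf.keras.optimizers.Adam(0.1)
    iterator = iter(toy_dataset())
    # Everything attached to the Checkpoint is saved and restored together.
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt,
                               net=net, iterator=iterator)
    manager = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=3)

    # restore(None) does nothing, so the first run starts from scratch.
    ckpt.restore(manager.latest_checkpoint)
    print("Restored from", manager.latest_checkpoint or "scratch")

    for _ in range(num_steps):
        loss = train_step(net, next(iterator), opt)
        ckpt.step.assign_add(1)
        if int(ckpt.step) % 10 == 0:
            path = manager.save()
            print(f"step {int(ckpt.step)}: saved {path}, loss {float(loss):.2f}")
    return int(ckpt.step), manager


first_step, _ = train_and_checkpoint()
second_step, manager = train_and_checkpoint()  # resumes from step 50
```

The second run continues the step counter, the save counter in the file names, and the dataset position from where the first run last saved. max_to_keep=3 keeps only the three newest checkpoints on disk.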


Restore a training session

Finally, let's look a little deeper into what a checkpoint contains and how a training session is restored. The checkpoint in the previous section saves more than the model parameters. It also contains the optimizer state, such as the learning rate and the iteration count, plus the slot variables the optimizer keeps for each trainable variable, for example Adam's first-moment estimate (m). It also holds the training state: the training step and the save counter, whose value is appended to the checkpoint file name (ckpt-1, ckpt-2, and so on). Hence, restoring the checkpoint also restores the optimizer state and the training state. The checkpoint also records the progress of the dataset iterator, so iteration resumes where it stopped instead of starting from the beginning.
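One way to see this is to list the variables stored in a checkpoint with `tf.train.list_variables`. The sketch below assumes TensorFlow 2.x and reuses the hypothetical `Net` setup. The exact key names, especially for optimizer slots, differ between TF/Keras versions, so treat the printed list as illustrative.

```python
import shutil

import tensorflow as tf


class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return tf.data.Dataset.from_tensor_slices(
        dict(x=inputs, y=labels)).repeat().batch(2)


def train_step(net, example, optimizer):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.abs(net(example["x"]) - example["y"]))
    grads = tape.gradient(loss, net.trainable_variables)
    optimizer.apply_gradients(zip(grads, net.trainable_variables))
    return loss


shutil.rmtree("./tf_ckpts_inspect", ignore_errors=True)
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
iterator = iter(toy_dataset())
for _ in range(3):
    train_step(net, next(iterator), opt)

ckpt = tf.train.Checkpoint(step=tf.Variable(3), optimizer=opt,
                           net=net, iterator=iterator)
save_path = ckpt.save("./tf_ckpts_inspect/ckpt")  # first save ends in "ckpt-1"

# Each entry is (path, shape): layer weights, optimizer variables,
# step, save_counter, and the serialized iterator state.
for name, shape in tf.train.list_variables(save_path):
    print(name, shape)
```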


checkpoint.restore restores variable values for every path that matches between the checkpoint and the object being restored, so we can load just a subset of a checkpoint. For example, we can recreate only part of the model and restore just the bias of the self.l1 Dense layer from the checkpoint.
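Here is a sketch of a partial restore, following the TensorFlow Guide's approach and assuming TensorFlow 2.x. It rebuilds only the object path `net/l1/bias` from placeholder `tf.train.Checkpoint` objects and restores just that variable. The attribute names must match those used when saving, and with Keras 3 some layer attribute names (for example the Dense kernel) differ internally, which is why only the bias is used here.

```python
import shutil

import numpy as np
import tensorflow as tf


class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return tf.data.Dataset.from_tensor_slices(
        dict(x=inputs, y=labels)).repeat().batch(2)


def train_step(net, example, optimizer):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.abs(net(example["x"]) - example["y"]))
    grads = tape.gradient(loss, net.trainable_variables)
    optimizer.apply_gradients(zip(grads, net.trainable_variables))
    return loss


# Train briefly and save a full checkpoint.
shutil.rmtree("./tf_ckpts_partial", ignore_errors=True)
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
iterator = iter(toy_dataset())
for _ in range(5):
    train_step(net, next(iterator), opt)
save_path = tf.train.Checkpoint(net=net, optimizer=opt,
                                iterator=iterator).save("./tf_ckpts_partial/ckpt")

# Rebuild only the path net -> l1 -> bias and restore that one variable.
to_restore = tf.Variable(tf.zeros([5]))
before = to_restore.numpy().copy()          # all zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(save_path)
status.expect_partial()  # silence warnings about the parts we skipped
print(before, "->", to_restore.numpy())
```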


Copy Weights

Weights can also be copied from one layer to another with get_weights and set_weights, provided the two layers have weights of the same shapes.
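The sketch below shows a layer-to-layer copy, in the style of the TensorFlow Guide and assuming TensorFlow 2.x. The `create_layer` helper and the 784-input shape are illustrative. Each layer is built explicitly so that its weights exist before copying.

```python
import numpy as np
import tensorflow as tf


def create_layer():
    """Hypothetical helper: a Dense layer built for 784 input features."""
    layer = tf.keras.layers.Dense(64, activation="relu", name="dense_2")
    layer.build((None, 784))  # creates the kernel and bias variables
    return layer


layer_1 = create_layer()
layer_2 = create_layer()  # different random initial weights

# get_weights returns NumPy arrays; set_weights assigns them in order.
layer_2.set_weights(layer_1.get_weights())
```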


Even though functional_model_with_dropout has an extra dropout layer compared to functional_model, the dropout layer has no weights. The two models therefore have the same list of weights, so we can still copy weights from functional_model to functional_model_with_dropout using model.set_weights.
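This sketch follows the TensorFlow Guide's functional-model example and assumes TensorFlow 2.x. The layer sizes and names are illustrative. Because `Dropout` holds no weights, both models expose the same ordered list of arrays.

```python
import numpy as np
import tensorflow as tf

keras = tf.keras

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs, name="3_layer_mlp")

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = keras.layers.Dense(64, activation="relu", name="dense_2")(x)
x = keras.layers.Dropout(0.5)(x)  # extra layer, but it has no weights
outputs = keras.layers.Dense(10, name="predictions")(x)
functional_model_with_dropout = keras.Model(
    inputs=inputs, outputs=outputs, name="3_layer_mlp")

# Same number, order, and shapes of weight arrays, so the copy lines up.
functional_model_with_dropout.set_weights(functional_model.get_weights())
```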


Credits & References

Most of the code in this article is adapted from the TensorFlow Guide.

