TensorFlow Custom training, Transfer learning & Custom layers

Jonathan Hui
Jan 20, 2021


In this article, we discuss lower-level APIs for building and training a model. They give us more flexibility, such as that required for GANs. We will also show examples of transfer learning.

Model Training with GradientTape

Keras provides training and evaluation methods, fit() and evaluate(), out-of-the-box. By declaring the optimizer and the loss function to be used, the gradient descent will be performed automatically. However, for advanced optimization algorithms, we need access to the inner mechanism of computing gradients and performing gradient descent. In TensorFlow (TF), this can be done using tf.GradientTape.

The example in this section uses lower-level APIs to build and train a regression model (y = wx + b). First, GradientTape records (tapes) the forward pass. Later, to perform backpropagation, we use the recorded tape t to compute the loss gradient w.r.t. w and b. Finally, the model parameters are updated using gradient descent. Spelling out these extra steps lets us inject code that manipulates and applies gradients explicitly for customized optimization algorithms. Let’s look at the code for a custom model, a dense model, first.

In __init__ of MyModel, we create trainable parameters w and b. The callable __call__ is a simple Python mechanism to produce the model output when model(input) is called. Next, we record the forward pass operations with GradientTape t. Then, in backpropagation, t.gradient computes the loss gradient w.r.t. w and b respectively. (Please refer to this article if you need more information on the AutoDiff.) Next, we update w and b with gradient descent using assign_sub. The last part of the code performs the model training.
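As a minimal sketch of the steps just described (the synthetic data, the initial parameter values, the learning rate, and the number of iterations are assumptions chosen for illustration):

```python
import tensorflow as tf


class MyModel:
    """A one-feature linear model y = w * x + b."""

    def __init__(self):
        # Trainable parameters; the initial values are arbitrary.
        self.w = tf.Variable(5.0)
        self.b = tf.Variable(0.0)

    def __call__(self, x):
        return self.w * x + self.b


def loss_fn(y_true, y_pred):
    # Mean squared error.
    return tf.reduce_mean(tf.square(y_true - y_pred))


def train_step(model, x, y, learning_rate):
    # Record the forward pass so that gradients can be computed afterwards.
    with tf.GradientTape() as t:
        current_loss = loss_fn(y, model(x))
    # Loss gradient w.r.t. w and b.
    dw, db = t.gradient(current_loss, [model.w, model.b])
    # Plain gradient descent update.
    model.w.assign_sub(learning_rate * dw)
    model.b.assign_sub(learning_rate * db)
    return current_loss


# Synthetic data drawn around y = 3x + 2 (assumed values).
tf.random.set_seed(0)
x = tf.random.normal([1000])
y = 3.0 * x + 2.0 + 0.1 * tf.random.normal([1000])

model = MyModel()
losses = [float(train_step(model, x, y, learning_rate=0.1)) for _ in range(100)]
print(f"w={float(model.w):.2f}, b={float(model.b):.2f}, loss={losses[-1]:.4f}")
```

Because the update is written out by hand, this is the place where a custom rule (gradient clipping, a different step size per parameter, and so on) could be inserted.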

Once the training is finished, we visualize the data and the regression model.

MNIST with GradientTape

Here is another example of training a neural network (NN) model, an MNIST classifier. The first part of the code prepares the Dataset. It wraps the NumPy ndarrays x_train and y_train to create mini-batches of 32 samples.

Then, we create a model.

Next, we will define the optimizer, the loss function, and the metrics to be used.

Here are the training and testing steps. In training, we make predictions and compute the loss with the defined loss function. Next, we compute the loss gradient w.r.t. all trainable weights. model.trainable_variables keeps a list of the trainable parameters, so we do not need to track them ourselves. Then, we apply the chosen optimizer to perform gradient descent using the computed gradients. Finally, we log both selected metrics (loss and accuracy). Testing has similar steps without the gradient descent optimization.

Finally, we run the training and at the beginning of every epoch, we reset the metrics.
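The training step and epoch loop can be sketched as follows. To keep the example self-contained, random arrays stand in for MNIST and the model is a small dense network; both are assumptions, as are the optimizer choice and the number of epochs.

```python
import numpy as np
import tensorflow as tf

# Random stand-in for MNIST: 256 "images" of 28x28 pixels in 10 classes.
rng = np.random.default_rng(0)
x_train = rng.random((256, 28, 28)).astype("float32")
y_train = rng.integers(0, 10, size=(256,)).astype("int64")
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(256).batch(32)
)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10),  # logits
])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")


def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    # Gradients w.r.t. every trainable weight tracked by the model.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Log the metrics for this batch.
    train_loss.update_state(loss)
    train_accuracy.update_state(labels, predictions)


for epoch in range(2):
    # Reset the metrics at the beginning of every epoch.
    train_loss.reset_state()
    train_accuracy.reset_state()
    for images, labels in train_ds:
        train_step(images, labels)
    print(
        f"epoch {epoch + 1}: loss={float(train_loss.result()):.4f}, "
        f"accuracy={float(train_accuracy.result()):.4f}"
    )
```

A test step would look the same minus the tape and the apply_gradients call, with training=False passed to the model.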

Here is another version. It uses the functional API to build a model directly, with dense layers only. The code also splits off part of the training dataset as a validation dataset.

The training iteration is similar. But we use an SGD optimizer instead. Also, the training includes an extra step of validation check.

GAN with GradientTape

Custom gradient descent is often used in GANs. The code below creates a discriminator and a generator for a GAN in which the generator is based on transpose convolution (line 30).

Here is the model summary for the discriminator and the generator. As shown, global_max_pooling2d takes the maximum value over the spatial dimensions of each channel.

Then, we sample latent vectors from a normal distribution and generate images from the generator (line 51). We concatenate the generated images and the real images to form a large Tensor. To create the corresponding labels, we concatenate an all-ones Tensor for the generated images with an all-zeros Tensor for the real images (line 54). Then, we compute the cross-entropy loss for the discriminator: we expect the discriminator to classify real images as real (0) and generated images as fake (1).

In line 66, we generate latent factors again and label them as 0. We pass these factors through the generator and the discriminator in line 72. Again, we compute the cross-entropy loss, this time hoping that the discriminator will misidentify these images as real (label 0). The loss gradients for the discriminator and the generator are used to update their own parameters respectively (lines 64 and 75). Finally, the code below prepares the dataset and the training.
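The two-phase update can be sketched on a toy problem. The networks below are small dense models acting on 2-D points instead of the convolutional models for images, and the latent size, batch size, and learning rates are assumptions. The label convention follows the text: 1 for generated samples, 0 for real ones.

```python
import tensorflow as tf

latent_dim = 8  # assumed size of the latent vector
data_dim = 2    # toy "images" are 2-D points

# Hypothetical toy networks standing in for the convolutional ones.
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(data_dim),
])
discriminator = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),  # logit; a high value means "generated"
])

d_optimizer = tf.keras.optimizers.Adam(1e-3)
g_optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def train_step(real_samples):
    batch_size = real_samples.shape[0]

    # 1) Discriminator update: generated samples are labeled 1, real ones 0.
    latent = tf.random.normal((batch_size, latent_dim))
    generated = generator(latent)
    combined = tf.concat([generated, real_samples], axis=0)
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
    )
    with tf.GradientTape() as tape:
        d_loss = loss_fn(labels, discriminator(combined))
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # 2) Generator update: fresh fakes are labeled 0 ("real") so that the
    #    generator is rewarded when the discriminator is fooled.
    latent = tf.random.normal((batch_size, latent_dim))
    misleading_labels = tf.zeros((batch_size, 1))
    with tf.GradientTape() as tape:
        g_loss = loss_fn(misleading_labels, discriminator(generator(latent)))
    # Only the generator weights are updated in this phase.
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
    return d_loss, g_loss


real = tf.random.normal((32, data_dim), mean=3.0)
d_loss, g_loss = train_step(real)
print(float(d_loss), float(g_loss))
```

Each phase uses its own tape and its own optimizer, which is exactly what fit() cannot express without customization.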

TensorFlow Hub

TF Hub provides pre-trained models and layers. For example, we can download a pre-built embedding layer to encode a movie text review into a 50-D vector.

Here is a summary of the model built. We feed the 50-D vector from the pre-trained embedding layer into 2 dense layers to predict whether it is a good review.

Transfer Learning

DL models are hard to train. Therefore, many projects start with a pre-trained model and apply transfer learning. Just for completeness, here is the code for preparing the datasets of images that contain a dog or a cat.

Then we create a pre-trained MobileNet V2 model and mark all of its weights as non-trainable.

Next, we add our own classification head in predicting the image as a cat or a dog.

We build a new model using the base_model and the new head. It also contains a preprocess_input layer to scale the input to the range of [-1, 1].

Next, we train the model. Because we set base_model.trainable = False before, only the added classification head will be trained.

Once 10 epochs are done, we want to fine-tune the model, including some layers in base_model. First, we set base_model.trainable = True and then set the first 100 layers to be non-trainable again. Whenever the training configuration or the trainable state of the model is changed, the model must be compiled again (model.compile) for the changes to take effect. Finally, we fit the model again. In model.fit, we set initial_epoch to history.epoch[-1], the zero-based index of the last epoch of the first run (9 after 10 epochs), so that training resumes where the first run stopped. This keeps the epoch count continuous, which matters for anything that depends on the epoch number, such as an epoch-based learning-rate schedule.
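The freeze, train, unfreeze, recompile sequence can be sketched with a small stand-in for the base model. The six-layer dense base_model, the cut-off of 4 layers, the random data, and the learning rates are all assumptions that keep the example fast and download-free; the article's code uses MobileNet V2 and a cut-off of 100.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a pre-trained base such as MobileNet V2.
base_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(8, activation="relu") for _ in range(6)]
)
head = tf.keras.layers.Dense(1)  # new classification head (logit output)

inputs = tf.keras.Input(shape=(4,))
x = base_model(inputs, training=False)  # run the base in inference mode
model = tf.keras.Model(inputs, head(x))


def compile_model(learning_rate):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    )


# Made-up data, only so that fit() has something to run on.
features = np.random.rand(32, 4).astype("float32")
labels = np.random.randint(0, 2, size=(32, 1)).astype("float32")

# Phase 1: freeze the whole base and train only the head.
base_model.trainable = False
compile_model(1e-3)
head_only_count = len(model.trainable_variables)  # head kernel and bias
history = model.fit(features, labels, epochs=2, verbose=0)

# Phase 2: unfreeze the base, then freeze its first layers again.
fine_tune_at = 4  # assumed cut-off for this toy base
base_model.trainable = True
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
# Recompile so the change takes effect (the lower rate is an assumption).
compile_model(1e-4)
fine_tune_count = len(model.trainable_variables)

history_fine = model.fit(
    features, labels, epochs=4, initial_epoch=history.epoch[-1], verbose=0
)
print(head_only_count, fine_tune_count, history_fine.epoch)
```

Counting model.trainable_variables before and after is a quick check that the intended layers, and only those, were unfrozen.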

Next, we evaluate the new model, make some predictions, and plot them out.

There is one important line of code earlier that makes the fine-tuning work.

Basically, it instructs the base model to run in inference mode even during training. Some layers behave differently in inference mode than in training mode. For example, a dropout layer is skipped when the model runs in inference mode. This setting is different from layer.trainable, which only indicates whether the layer’s parameters should be updated in backpropagation. Let’s demonstrate it with batch normalization (BN) during fine-tuning. In BN, training=False instructs the layer not to use the mean and variance of the current batch. Instead, it uses the moving mean and variance accumulated during the previous training. We want to fine-tune the model under the statistics collected from the previous training, which used a much larger dataset.
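A small experiment makes the difference visible. The batch below is deliberately drawn far from zero mean and unit variance (the shape and statistics are assumptions), and the BN layer is freshly created, so its moving statistics are still close to their initial values.

```python
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal((64, 3), mean=10.0, stddev=2.0)

# training=True: normalize with the mean and variance of the current batch
# (and update the moving statistics as a side effect).
y_train_mode = bn(x, training=True)

# training=False: normalize with the stored moving mean and variance,
# which have barely moved away from their initial values of 0 and 1.
y_infer_mode = bn(x, training=False)

train_mean = float(tf.reduce_mean(y_train_mode))
infer_mean = float(tf.reduce_mean(y_infer_mode))
print(train_mean)  # close to 0: the batch statistics were used
print(infer_mean)  # far from 0: the stored statistics were used
```

In fine-tuning the stored statistics come from the original large-scale training, which is why the base model is called with training=False.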

Custom Layer

Keras comes with many predefined NN layers. But we can also create custom layers ourselves by extending layers.Layer. The following code implements a dense layer.

Just as a reference, here is the code for a custom layer in which you want different behaviors in training and inference mode (as discussed in the last section).
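A minimal sketch of such a dense layer, with the weights created in the constructor. The class name Linear, the initializers, and the sizes are assumptions for illustration.

```python
import tensorflow as tf


class Linear(tf.keras.layers.Layer):
    """A dense layer computing inputs @ w + b."""

    def __init__(self, units=32, input_dim=32):
        super().__init__()
        # The input size is known up front, so the weights are created here.
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(
            shape=(units,), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


layer = Linear(units=4, input_dim=2)
y = layer(tf.ones((3, 2)))
print(y.shape)  # (3, 4)
```

Weights created with add_weight are tracked automatically, so they show up in layer.trainable_weights without extra bookkeeping.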


In TF, it is not required to have the input_shape known when a layer or a model is instantiated (line 18).

But without this information, a layer cannot create its weights, because TF does not yet know their shape. The weights of the first dense layer could have a shape of (5, 2), (50, 2), or something else. When the model is first invoked with an input in line 27, TF calls “model.build” to instantiate the model parameters for that input shape (this will be (3, 3) in our example). If the model is later called with inputs of an incompatible shape, an exception will be thrown.

Alternatively, instead of invoking the model, we can call build explicitly with the input shape to instantiate the weights. In line 14 below, a BatchNormalization layer is instantiated. Since no input shape is provided, no variables are created. When “build” is called with an input shape, 2 trainable parameters and 2 non-trainable parameters (mean and variance) are created.
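Both ways of triggering weight creation can be checked by counting variables. The layer sizes and the feature count of 8 passed to build are assumptions; the (3, 3) input matches the shape mentioned above.

```python
import tensorflow as tf

# Deferred creation: the first Dense layer does not know its input size yet.
first_dense = tf.keras.layers.Dense(2, activation="relu")
model = tf.keras.Sequential([first_dense, tf.keras.layers.Dense(1)])
dense_weights_before_call = len(first_dense.weights)  # nothing created yet

_ = model(tf.ones((3, 3)))  # the first call builds the weights
first_kernel_shape = tuple(first_dense.kernel.shape)

# Explicit creation: call build with an input shape instead of real data.
bn = tf.keras.layers.BatchNormalization()
bn_weights_before_build = len(bn.weights)
bn.build((None, 8))
bn_trainable = len(bn.trainable_weights)          # gamma and beta
bn_non_trainable = len(bn.non_trainable_weights)  # moving mean and variance

print(dense_weights_before_call, first_kernel_shape)
print(bn_weights_before_build, bn_trainable, bn_non_trainable)
```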

Here is an example of initializing the FC layer parameters within “build”. The code below also includes a get_config method so a model can be instantiated from a model configuration of another instance.
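A sketch of this pattern, again with a hypothetical Linear layer: the constructor stores only the configuration, build creates the weights once the input size is known, and get_config returns what is needed to recreate the layer.

```python
import tensorflow as tf


class Linear(tf.keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Called once, with the shape of the first input the layer sees.
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        # Start from the base config (name, dtype, ...) and add our argument.
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = Linear(64)
config = layer.get_config()
new_layer = Linear.from_config(config)  # same configuration, fresh weights
out = new_layer(tf.ones((2, 3)))
print(new_layer.units, out.shape)
```

The configuration carries constructor arguments only. Weight values are not copied; they would be transferred separately, for example with get_weights and set_weights.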

As a footnote, if a model has not been built, model.summary will throw an exception.

The following example shows how to add a non-trainable variable to a layer that sums up the inputs.
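A minimal sketch of such a layer follows. The class name ComputeSum and the input size are assumptions; the running total is created with trainable=False, so it is tracked by the layer but ignored by gradient descent.

```python
import tensorflow as tf


class ComputeSum(tf.keras.layers.Layer):
    """Keeps a running column-wise sum of every input it has seen."""

    def __init__(self, input_dim):
        super().__init__()
        self.total = self.add_weight(
            shape=(input_dim,), initializer="zeros", trainable=False
        )

    def call(self, inputs):
        new_total = self.total + tf.reduce_sum(inputs, axis=0)
        self.total.assign(new_total)  # persist the state between calls
        return new_total


my_sum = ComputeSum(2)
x = tf.ones((2, 2))
first = my_sum(x).numpy()
second = my_sum(x).numpy()
print(first, second)  # [2. 2.] [4. 4.]
```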


A layer can be composed of other layers. It is a best practice to instantiate them in __init__ so that they are built when this layer is built.

add_loss (optional)

We can add a loss value in the callable. To retrieve the losses of the layer and its sublayers, call layer.losses.
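For example, a pass-through layer can register an activity penalty on its input. The layer name and the rate value below are assumptions for illustration.

```python
import tensorflow as tf


class ActivityRegularizationLayer(tf.keras.layers.Layer):
    """Returns its input unchanged and records a penalty on its magnitude."""

    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        self.add_loss(self.rate * tf.reduce_sum(inputs))
        return inputs


layer = ActivityRegularizationLayer(rate=0.5)
_ = layer(tf.ones((2, 2)))
loss_value = float(layer.losses[0])
print(len(layer.losses), loss_value)  # 1 2.0
```

The recorded value is 0.5 times the sum of four ones. In a custom training loop, the entries of layer.losses (or model.losses) are added to the main loss inside the tape.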

We can also add a loss to a model.

This is another example of adding a loss to a model.

In the custom optimization code below, we retrieve all the losses to perform the gradient descents.

If model.fit is used to train the model, all added losses are automatically included in the gradient descent. Nothing else is needed. Hence, we can simply use model.fit to train a model.

add_metric (optional)

We can also add a metric in a layer.

Custom Models

This section shows how to create a custom model. The APIs for a model and a layer are similar. Indeed, keras.Model inherits from keras.layers.Layer. But a model has extra APIs such as fit, evaluate, and save.


Let’s put them together to build an autoencoder.

The autoencoder is composed of 3 layer classes: Encoder → Sampling (part of the encoder) → Decoder.

The key objective of the encoder is to produce the mean (dense_mean) and log variance (dense_log_var) of a Gaussian distribution over a 32-D latent factor z.

Then we use the Sampling sublayer to sample z from this Gaussian distribution.

Then, we regenerate the image from this sampled latent factor using a dense layer in the Decoder.

Next, we compose a VAE model using these custom layers.

Without proof here, we want to maximize log p(x), the log-likelihood of the image, with the equation below. (The proof is very tedious, so let’s take it on faith for now.) It includes a term that can be interpreted as keeping the distribution of z from diverging from a standard normal distribution, as measured by the KL-divergence.

So in VariationalAutoEncoder.call, it adds a KL divergence loss.

Finally, here is the code for preparing the dataset and the training steps.

In line 100, we compute an MSE reconstruction loss for the reconstructed images. We then add it to the losses registered by the model (the KL-divergence).
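The pieces can be put together in a compact sketch. For brevity the encoder and decoder are folded into a single model class rather than separate Encoder and Decoder layers, random data stands in for flattened images, and the layer sizes are assumptions. The KL term is averaged over all elements, which is a scaling choice rather than the only possible one.

```python
import tensorflow as tf


class Sampling(tf.keras.layers.Layer):
    """Draws z = mean + exp(0.5 * log_var) * eps, with eps ~ N(0, I)."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        eps = tf.random.normal(tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * eps


def kl_divergence(z_mean, z_log_var):
    # KL(N(mean, var) || N(0, I)), averaged over batch and latent dimensions.
    return -0.5 * tf.reduce_mean(
        z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1.0
    )


class VariationalAutoEncoder(tf.keras.Model):
    def __init__(self, original_dim, intermediate_dim=64, latent_dim=32):
        super().__init__()
        self.encoder_hidden = tf.keras.layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = tf.keras.layers.Dense(latent_dim)
        self.dense_log_var = tf.keras.layers.Dense(latent_dim)
        self.sampling = Sampling()
        self.decoder_hidden = tf.keras.layers.Dense(intermediate_dim, activation="relu")
        self.decoder_output = tf.keras.layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        h = self.encoder_hidden(inputs)
        z_mean = self.dense_mean(h)
        z_log_var = self.dense_log_var(h)
        z = self.sampling((z_mean, z_log_var))
        reconstructed = self.decoder_output(self.decoder_hidden(z))
        # Register the KL term so that it appears in self.losses.
        self.add_loss(kl_divergence(z_mean, z_log_var))
        return reconstructed


vae = VariationalAutoEncoder(original_dim=784)
optimizer = tf.keras.optimizers.Adam(1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

x_batch = tf.random.uniform((64, 784))  # stand-in for flattened images
with tf.GradientTape() as tape:
    reconstructed = vae(x_batch)
    # Reconstruction loss plus the KL loss registered in call().
    loss = mse_loss_fn(x_batch, reconstructed) + sum(vae.losses)
grads = tape.gradient(loss, vae.trainable_weights)
optimizer.apply_gradients(zip(grads, vae.trainable_weights))
print(float(loss))
```

The KL term vanishes when the encoder outputs a zero mean and a zero log variance, that is, when the approximate posterior is exactly the standard normal.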

Actually, we do not make any special customization to the training iteration. Therefore, we can simply use model.compile and model.fit to train the model with an added MSE loss function defined.


During training, we can pad variable-length samples to fixed-length so we can pass the samples as a tensor. This happens most often in sequence models.

Mask generating layers

In mask-generating layers like Embedding, when mask_zero is set to True, the layer output will also include a property called _keras_mask. It is a mask that can propagate to subsequent layers to indicate what part of its input can be ignored.

Or we can add a keras.layers.Masking layer so that later layers, say an LSTM layer, can ignore certain timesteps. Consider a NumPy ndarray of shape (batch_size, timesteps, features). We set timesteps 3 and 5 to zero to indicate that these values are missing. With mask_value = 0, the layer creates a mask for the later LSTM layer so that it can skip those timesteps.
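Both kinds of mask can be inspected directly by calling compute_mask on the layer. The token ids, vocabulary size, and array shapes below are made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Three sequences, already padded with zeros to a common length of 6.
padded = tf.constant([
    [711, 632, 71, 0, 0, 0],
    [73, 8, 3215, 55, 927, 0],
    [83, 91, 1, 645, 1253, 927],
])

# Embedding with mask_zero=True: token id 0 is treated as padding.
embedding = tf.keras.layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)
embedding_mask = np.array(embedding.compute_mask(padded))

# Masking layer: a timestep is masked when all of its features equal mask_value.
x = np.random.rand(2, 10, 4).astype("float32")
x[:, 3, :] = 0.0
x[:, 5, :] = 0.0
masking = tf.keras.layers.Masking(mask_value=0.0)
timestep_mask = np.array(masking.compute_mask(tf.constant(x)))

print(embedding_mask[0])  # True for real tokens, False for padding
print(timestep_mask[0])   # False at timesteps 3 and 5
```

Each mask is a boolean tensor of shape (batch_size, timesteps), with False marking the entries that later layers should skip.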

Passing a mask implicitly

When a Sequential model or the functional API is used, the mask will be propagated automatically to later layers that are capable of using them.

Passing a mask explicitly

Or we can retrieve and pass the mask explicitly to a layer (lines 56 and 57).

Mask generating custom layer

If a custom layer can generate a mask (like the embedding layer), it implements the compute_mask method.
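A sketch of a mask-generating layer modeled on the embedding case. CustomEmbedding is a hypothetical name, and the lookup is written with tf.gather as one possible implementation.

```python
import tensorflow as tf


class CustomEmbedding(tf.keras.layers.Layer):
    def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.mask_zero = mask_zero

    def build(self, input_shape):
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # Look up one embedding row per token id.
        return tf.gather(tf.convert_to_tensor(self.embeddings), inputs)

    def compute_mask(self, inputs, mask=None):
        # Token id 0 marks padding when mask_zero is enabled.
        if not self.mask_zero:
            return None
        return tf.not_equal(inputs, 0)


layer = CustomEmbedding(10, 4, mask_zero=True)
token_ids = tf.constant([[3, 7, 0, 0], [5, 0, 0, 0]])
embedded = layer(token_ids)
mask = layer.compute_mask(token_ids)
print(embedded.shape)  # (2, 4, 4)
print(mask.numpy())
```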

This is another possible implementation.

Opt-in in passing a mask

By default, a custom layer destroys the current mask, so later layers will not receive it. To let the mask propagate forward unchanged, set self.supports_masking to True.

Process the mask

If the output of a layer depends on the mask, the layer can add a mask parameter to its call method. In the example below, all the scores from a skipped entry are ignored when calculating the softmax value.
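A sketch of a mask-consuming layer: a softmax over the time axis that gives masked timesteps zero weight. The class name TemporalSoftmax, the choice of axis, and the fallback for a missing mask are assumptions.

```python
import tensorflow as tf


class TemporalSoftmax(tf.keras.layers.Layer):
    """Softmax over the time axis that ignores masked timesteps."""

    def call(self, inputs, mask=None):
        # inputs: (batch, timesteps, features); mask: (batch, timesteps).
        if mask is None:
            return tf.nn.softmax(inputs, axis=1)
        float_mask = tf.expand_dims(tf.cast(mask, "float32"), -1)
        # Masked timesteps contribute nothing to numerator or denominator.
        inputs_exp = tf.exp(inputs) * float_mask
        inputs_sum = tf.reduce_sum(inputs_exp, axis=1, keepdims=True)
        return inputs_exp / inputs_sum


scores = tf.zeros((1, 4, 1))
mask = tf.constant([[True, True, False, False]])
probs = TemporalSoftmax()(scores, mask=mask)
print(probs.numpy().reshape(-1))  # [0.5 0.5 0.  0. ]
```

With equal scores and two unmasked timesteps, each of them receives a weight of 0.5 and the masked ones receive 0.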

Credit & References

The code in this article series originates from the TensorFlow guide.