TensorFlow Libraries and Extensions
In this article, we give an overview of some of the key extensions and libraries in TensorFlow 2.x (TF 2.x). These include TF Datasets, TF Hub, XLA, model optimization, TensorBoard, TF Probability, Neural Structured Learning, TF Serving, TF Federated, TF Graphics, and MLIR.
TensorFlow Datasets (TFDS) supports loading many popular datasets. For a complete list, check out the TF Datasets catalog. Here is a code sample for loading the MNIST data.
import tensorflow as tf
import tensorflow_datasets as tfds
# Construct a tf.data.Dataset.
# The data is downloaded to ~/tensorflow_datasets/mnist on first use.
ds = tfds.load('mnist', split='train', shuffle_files=True)
# Build your input pipeline: shuffle, batch, and prefetch.
# (Use tf.data.experimental.AUTOTUNE on TF versions before 2.4.)
ds = ds.shuffle(1024).batch(32).prefetch(tf.data.AUTOTUNE)
for example in ds.take(1):
    image, label = example["image"], example["label"]
    assert image.shape == (32, 28, 28, 1)
    assert label.shape == (32,)
TensorFlow Hub is a repository of trained machine learning models, like BERT, that can be fine-tuned or deployed as they are. For a complete list, check out TF Hub. In the code below, we load a token-based text embedding model trained on the English Google News 200B corpus.
#pip install --upgrade tensorflow_hub
import tensorflow_hub as hub
model = hub.KerasLayer("https://tfhub.dev/google/nnlm-en-dim128/2")
embeddings = model(["The rain in Spain.", "falls",
                    "mainly", "In the plain!"])
assert embeddings.shape == [4, 128]
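A trained model can be further optimized with little or no accuracy loss using weight pruning and/or quantization. Weight pruning zeroes out the weights that contribute least, typically those with the smallest magnitudes, so the model becomes sparse and compresses well. The code below optimizes an already trained “model” using weight pruning.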
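To make the idea concrete, here is a minimal, library-independent sketch of magnitude-based pruning on a single weight tensor. It uses plain NumPy rather than the TensorFlow Model Optimization API, and the function name `prune_by_magnitude` and the `sparsity` parameter are hypothetical names chosen for illustration. In practice, pruning is usually applied gradually during fine-tuning rather than in one shot as shown here.

```python
import numpy as np


def prune_by_magnitude(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries zeroed.

    `sparsity` is the target fraction of entries to set to zero (0.0 to 1.0).
    This is an illustrative helper, not a TensorFlow API.
    """
    weights = np.asarray(weights, dtype=np.float32)
    num_to_prune = int(round(sparsity * weights.size))
    if num_to_prune == 0:
        return weights.copy()
    # Flat indices of the `num_to_prune` entries with the smallest |w|.
    prune_idx = np.argsort(np.abs(weights), axis=None)[:num_to_prune]
    keep_mask = np.ones(weights.size, dtype=bool)
    keep_mask[prune_idx] = False
    # Keep the surviving weights unchanged and zero out the rest.
    return np.where(keep_mask.reshape(weights.shape), weights, 0.0)


layer = np.array([[0.9, -0.05, 0.3],
                  [0.01, -0.7, 0.2]])
pruned = prune_by_magnitude(layer, sparsity=0.5)
print(pruned)
```

With a sparsity of 0.5, the three smallest-magnitude weights (0.01, -0.05, and 0.2) are zeroed and the other three are left untouched.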
The memory footprint can also be reduced using weight clustering. Clustering first groups the weights of each layer into N clusters, then replaces every weight in a cluster with that cluster’s centroid value, so each layer stores only N distinct values.
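The two steps of weight clustering can be sketched with a small one-dimensional k-means in NumPy. This is an illustration under stated assumptions, not the TensorFlow Model Optimization implementation: the helper name `cluster_weights`, the evenly spaced (linear) centroid initialization, and the fixed iteration count are all choices made for this sketch.

```python
import numpy as np


def cluster_weights(weights, num_clusters, num_iters=20):
    """Share one centroid value among all weights assigned to the same cluster.

    Illustrative helper (not a TensorFlow API). Returns the clustered weights
    and the centroid values.
    """
    weights = np.asarray(weights, dtype=np.float32)
    flat = weights.ravel()
    # Assumption: spread the initial centroids evenly over the weight range.
    centroids = np.linspace(flat.min(), flat.max(), num_clusters)

    def nearest(values, centers):
        # Index of the closest centroid for every weight.
        return np.argmin(np.abs(values[:, None] - centers[None, :]), axis=1)

    for _ in range(num_iters):
        assignments = nearest(flat, centroids)
        for k in range(num_clusters):
            members = flat[assignments == k]
            if members.size > 0:
                centroids[k] = members.mean()

    assignments = nearest(flat, centroids)
    clustered = centroids[assignments].reshape(weights.shape)
    return clustered, centroids


layer = np.array([[0.11, 0.09, -0.52],
                  [0.48, -0.49, 0.52]])
clustered, centroids = cluster_weights(layer, num_clusters=3)
print(clustered)
```

After clustering, the layer can be stored as N centroid values plus a small integer index per weight, which is where the memory saving comes from.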
TensorBoard is a visualization tool for a TF application. It displays the scalar metrics logged by the application (like accuracy and loss), the input data, the computation graph, and the distributions and histograms of the trainable parameters.
A TF application writes this information to log files, which the TensorBoard application then reads and displays.
TensorFlow Probability (TFP)
TFP is a library for probabilistic modeling. It provides probability distributions, variational inference, Markov chain Monte Carlo, and more.
The code below draws 100K samples from a standard normal distribution to use as features, then draws 100K Bernoulli labels whose logits are 1.618 times those features. With the data collected, the code fits a Bernoulli generalized linear model to it and recovers the model coefficient.
import tensorflow as tf
import tensorflow_probability as tfp
# Pretend to load synthetic data set.
features = tfp.distributions.Normal(loc=0., scale=1.).sample(int(100e3))
labels = tfp.distributions.Bernoulli(logits=1.618 * features).sample()
# Specify model.
model = tfp.glm.Bernoulli()
# Fit model given data.
coeffs, linear_response, is_converged, num_iter = tfp.glm.fit(
    model_matrix=features[:, tf.newaxis],
    response=tf.cast(labels, dtype=tf.float32),
    model=model)
# ==> coeffs is approximately [1.618] (We're golden!)
Neural Structured Learning (NSL)
In computer vision, information is encoded in an image. In NLP, it is contained in text. However, rich information can also be encoded in a graph that describes the relationships between samples. The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. We can use both the nodes (the paper content) and the links (citations) to better categorize each paper into one of seven categories. In the diagram below, we want the embedding features among neighbors to be similar.
For example, we can introduce a neighbor loss to penalize the difference (D) of the neighbors’ embeddings.
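As a rough sketch of such a neighbor loss, the NumPy code below takes D to be the squared L2 distance between two embeddings and adds the average over all edges to the supervised loss. The choice of distance, the weighting factor `alpha`, and the function names are assumptions made for illustration, not the NSL API.

```python
import numpy as np


def neighbor_loss(embeddings, edges, edge_weights=None):
    """Mean weighted squared L2 distance between embeddings of linked samples."""
    embeddings = np.asarray(embeddings, dtype=np.float64)
    edges = np.asarray(edges)
    if edge_weights is None:
        edge_weights = np.ones(len(edges))
    # Difference between the two endpoints of every edge (i, j).
    diffs = embeddings[edges[:, 0]] - embeddings[edges[:, 1]]
    distances = np.sum(diffs ** 2, axis=1)  # D(h_i, h_j) for each edge
    return float(np.mean(edge_weights * distances))


def total_loss(supervised_loss, embeddings, edges, alpha=0.1):
    """Supervised loss plus the neighbor penalty, scaled by `alpha`."""
    return supervised_loss + alpha * neighbor_loss(embeddings, edges)


# Three samples; sample 1 is linked to samples 0 and 2.
embeddings = np.array([[1.0, 0.0],
                       [1.0, 1.0],
                       [4.0, 4.0]])
edges = [(0, 1), (1, 2)]
print(neighbor_loss(embeddings, edges))
print(total_loss(0.5, embeddings, edges, alpha=0.1))
```

Minimizing the second term pulls the embeddings of linked samples together, while the supervised term keeps them useful for classification.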
With the additional edge information, we can apply graph regularization to tasks such as document and sentiment classification.
In computer vision, we can add noise to the pixels of an image to generate artificial neighbors, which makes the model more robust to adversarial attacks. Here is the code for building an adversarial regularization model on top of a deep learning model using NSL.
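The idea behind adversarial regularization can be shown without NSL on a tiny logistic-regression model. The sketch below is a NumPy illustration, not the NSL API: it builds an adversarial neighbor by moving the input a small step along the sign of the loss gradient (an FGSM-style perturbation, which is an assumption of this sketch) and adds the loss on that neighbor to the training objective. The names `multiplier` and `step_size` are illustrative.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def log_loss(w, b, x, y):
    """Binary cross-entropy of a logistic-regression model on one example."""
    p = sigmoid(x @ w + b)
    return float(-(y * np.log(p) + (1 - y) * np.log(1 - p)))


def adversarial_neighbor(w, b, x, y, step_size):
    """Perturb x by `step_size` along the sign of the loss gradient w.r.t. x."""
    p = sigmoid(x @ w + b)
    grad_x = (p - y) * w  # gradient of the log loss with respect to the input
    return x + step_size * np.sign(grad_x)


def adv_regularized_loss(w, b, x, y, multiplier=0.2, step_size=0.05):
    """Loss on the clean input plus a weighted loss on its adversarial neighbor."""
    x_adv = adversarial_neighbor(w, b, x, y, step_size)
    return log_loss(w, b, x, y) + multiplier * log_loss(w, b, x_adv, y)


w = np.array([1.5, -2.0, 0.5])
b = 0.1
x = np.array([0.2, 0.4, 0.9])  # stand-in for pixel values
y = 1.0

x_adv = adversarial_neighbor(w, b, x, y, step_size=0.05)
clean_loss = log_loss(w, b, x, y)
adv_loss = log_loss(w, b, x_adv, y)
combined = adv_regularized_loss(w, b, x, y)
print(clean_loss, adv_loss, combined)
```

Training against the combined loss encourages the model to give the same answer for an input and for its slightly perturbed neighbor.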
TensorFlow Serving (TFS)
TensorFlow Serving serves client requests to machine learning models in production environments. The commands below start a Docker container running TFS to deploy a model that computes y = x/2 + 2.
Below, the SavedModel “my_model” is served at port 8501.
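Once the model is being served, a client can query it over the TensorFlow Serving REST API. The sketch below uses only the Python standard library and assumes the server is reachable at `localhost:8501` under the model name `my_model`, as described above. The helper names are hypothetical, and `half_plus_two` is only a local reference for the values the server should return.

```python
import json
import urllib.request


def build_predict_request(instances, model_name="my_model",
                          host="localhost", port=8501):
    """Build a POST request for the TF Serving REST predict endpoint."""
    url = f"http://{host}:{port}/v1/models/{model_name}:predict"
    body = json.dumps({"instances": instances}).encode("utf-8")
    return urllib.request.Request(
        url, data=body,
        headers={"Content-Type": "application/json"},
        method="POST")


def predict(instances, **kwargs):
    """Send the request and return the predictions (needs a running server)."""
    request = build_predict_request(instances, **kwargs)
    with urllib.request.urlopen(request, timeout=10) as response:
        return json.loads(response.read())["predictions"]


def half_plus_two(x):
    """Local reference for the served function y = x/2 + 2."""
    return x / 2 + 2


inputs = [1.0, 2.0, 5.0]
request = build_predict_request(inputs)
print(request.full_url)
print([half_plus_two(x) for x in inputs])
```

With the container running, `predict([1.0, 2.0, 5.0])` should return values matching `half_plus_two` for each input.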
TensorFlow Federated (TFF)
TFF performs model training on decentralized data. It can train a shared model across participating clients that keep their training data local. For example, a mobile phone can help train a model without uploading sensitive user data to servers.
In Federated Learning, the client devices compute SGD updates on locally collected data. Only the model updates are sent to a remote server, where they are collected and aggregated, so the more sensitive user data never leaves the device. Finally, the aggregated model is sent back to the clients.
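The round trip described above can be simulated in a few lines of NumPy. This is a sketch of federated averaging on a toy linear-regression problem, not the TFF API. It assumes full-batch gradient steps on each client instead of mini-batch SGD, noiseless synthetic data, and aggregation weighted by each client's number of examples. All names are illustrative.

```python
import numpy as np


def local_update(weights, features, targets, learning_rate=0.1, num_steps=5):
    """Run a few gradient steps on one client's private data."""
    w = weights.copy()
    for _ in range(num_steps):
        residual = features @ w - targets
        gradient = 2.0 * features.T @ residual / len(targets)
        w -= learning_rate * gradient
    return w


def federated_averaging(client_data, num_rounds=100, num_features=2):
    """Server loop: broadcast the model, collect client models, average them."""
    global_w = np.zeros(num_features)
    total_examples = sum(len(targets) for _, targets in client_data)
    for _ in range(num_rounds):
        # Each client starts from the current global model; its data stays local.
        client_models = [local_update(global_w, f, t) for f, t in client_data]
        # The server only sees model weights, averaged by client data size.
        global_w = sum(
            (len(t) / total_examples) * w
            for (_, t), w in zip(client_data, client_models))
    return global_w


rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
client_data = []
for num_examples in (30, 50, 80):  # three clients with different data sizes
    features = rng.normal(size=(num_examples, 2))
    client_data.append((features, features @ true_w))

learned_w = federated_averaging(client_data)
print(learned_w)
```

The server recovers weights close to `true_w` even though it never accesses any client's features or targets.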
TensorFlow Graphics provides differentiable graphics and geometry layers (e.g. cameras, reflectance models, spatial transformations, mesh convolutions) that can be used to train machine learning models. We can also use the 3D TensorBoard to visualize the 3D renderings.
For example, we can train a neural network model to decompose an image into the corresponding scene parameters, which can then be used to render the scene. Such models are trained to minimize the reconstruction loss between the rendered scene and the original image.
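A toy version of this idea fits in a few lines of NumPy. In the sketch below, the "scene parameter" is a single rotation angle, the "renderer" is a 2-D rotation of a few points, and the angle is recovered by gradient descent on the reconstruction loss. It is a deliberately simplified stand-in that does not use TensorFlow Graphics, and the gradient is written out by hand, whereas a differentiable graphics layer would supply it automatically.

```python
import numpy as np


def render(points, angle):
    """Toy differentiable 'renderer': rotate 2-D points by `angle` radians."""
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[c, -s], [s, c]])
    return points @ rotation.T


def reconstruction_loss_and_grad(points, observed, angle):
    """Mean squared reconstruction error and its derivative w.r.t. the angle."""
    c, s = np.cos(angle), np.sin(angle)
    d_rotation = np.array([[-s, -c], [c, -s]])  # d(rotation)/d(angle)
    error = render(points, angle) - observed
    loss = np.mean(np.sum(error ** 2, axis=1))
    grad = np.mean(2.0 * np.sum(error * (points @ d_rotation.T), axis=1))
    return loss, grad


points = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
true_angle = 0.7
observed = render(points, true_angle)  # the "image" we want to explain

angle = 0.0  # initial guess for the scene parameter
for _ in range(200):
    loss, grad = reconstruction_loss_and_grad(points, observed, angle)
    angle -= 0.2 * grad  # gradient descent on the reconstruction loss

print(angle, loss)
```

Because the renderer is differentiable, the reconstruction error can be pushed back to the scene parameter, which is the property that lets graphics layers be used inside a trainable model.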
The following is sample code for rendering an object using TensorFlow Graphics.
MLIR creates an intermediate representation of a machine learning model or algorithm to be executed on an AI accelerator. This representation bridges the gap between the logical model and the physical chip design. The goal is for such a representation to be better optimized for, and executed on, AI chips.
XLA, a JIT compiler, takes computation graphs and optimizes them by fusing computation nodes and removing redundant ones. Then, it compiles them into sequences of kernels for the target devices with further device-specific optimizations. For example, on GPU devices, XLA fuses computation nodes that can be performed in a single GPU operation.
Credits and References
The code in this article is adapted or modified from the TF guides and documentation.
This article is an overview of TF extensions and libraries. However, TF design changes constantly. Please refer to the latest documentation for detailed APIs and implementations.