Applications of Graph Neural Networks (GNN)

In the two previous articles on GCN and GNN networks, we presented a design overview of Graph Neural Networks. In this final article, we focus on GNN applications. Since the basic GNN theory is already covered in those articles, we will not repeat it here. For the design details of each application, please refer to the original research papers.

Medical Diagnosis & Electronic Health Records Modeling

A medical ontology can be described by a graph. For example, the following diagram represents the ontology as a DAG (directed acyclic graph).

To make use of the ontology in medical diagnosis, we can learn a network embedding gᵢ for each node cᵢ in the graph below. This embedding is computed with attention over the learned embedding eᵢ of node i and the embeddings of its parents. To make diagnosis predictions, such as the type of disease or the chance of heart failure, we multiply the ontology knowledge G with the current visit information xₜ and make the final prediction with a neural network. This model can be further enhanced with information on previous visits by using an RNN instead of a simple neural network.
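
As a minimal sketch of the attention step described above, the embedding gᵢ can be formed as a softmax-weighted sum of the embeddings of node i and its parents. The dot-product score and the names `attention_embedding` and `parents` are assumptions made for illustration; the original paper defines its own scoring function.

```python
import numpy as np

def attention_embedding(E, node, parents):
    """Return g_i as an attention-weighted sum over node i and its parents.

    E       : (num_nodes, dim) array of learned embeddings e_j
    node    : index i of the node being embedded
    parents : list of parent indices of node i in the DAG
    """
    members = [node] + list(parents)
    # Assumed compatibility score: dot product between e_i and each e_j.
    scores = E[members] @ E[node]
    # Softmax turns the scores into attention weights that sum to 1.
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ E[members], weights

rng = np.random.default_rng(0)
E = rng.normal(size=(5, 4))   # 5 ontology nodes, 4-dimensional embeddings
g, alpha = attention_embedding(E, node=3, parents=[1, 0])
```

Here `g` is the final representation of node 3 and `alpha` holds one attention weight for the node itself and each of its parents.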

The following is a t-SNE scatterplot of the final representations gᵢ. As shown below, the learned embeddings cluster nicely into groups of diseases.

Drug discovery and synthesis of chemical compounds

Developing a single drug costs over $1 billion and takes 12 years or more. Applying AI is therefore commercially attractive if it can shorten the search for better drug candidates and the biomarker discovery phase, which together can take over 8 years.

A DNN can be trained on hundreds of thousands of chemical structures to encode and decode molecules, and to build predictors that estimate chemical properties from the latent representation. For example, we can train an autoencoder to encode the graph representation of a molecule and then reconstruct the molecule with a decoder. The training objective is to learn a latent representation that minimizes the reconstruction loss. These latent representations allow researchers to generate novel chemical structures automatically by performing simple operations in the latent space, such as perturbing known chemical structures or interpolating between molecules. We can also use this latent representation to predict synthetic accessibility and drug similarity with another DNN (the green network below).
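
The two latent-space operations mentioned above, perturbing and interpolating, can be sketched as follows. The vectors `z_a` and `z_b` are random stand-ins for the encoder output of two molecules, and the function names are hypothetical; in practice each resulting vector would be passed to the trained decoder.

```python
import numpy as np

def interpolate(z_a, z_b, steps):
    """Linearly interpolate between two latent vectors, endpoints included."""
    ts = np.linspace(0.0, 1.0, steps)
    return np.stack([(1.0 - t) * z_a + t * z_b for t in ts])

def perturb(z, scale, rng):
    """Add small Gaussian noise to explore the neighborhood of a latent vector."""
    return z + scale * rng.normal(size=z.shape)

rng = np.random.default_rng(0)
z_a = rng.normal(size=8)   # stand-in for the latent code of molecule A
z_b = rng.normal(size=8)   # stand-in for the latent code of molecule B

path = interpolate(z_a, z_b, steps=5)     # 5 latent points from A to B
z_new = perturb(z_a, scale=0.1, rng=rng)  # a point near molecule A
```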

Another project at MIT applies deep learning to graph objects to discover new antibiotics.

The Open Catalyst Project is another example of using AI to discover new catalysts for use in renewable energy storage.

The diagram below is another view of applying AI to develop new drugs and vaccines from an existing database of molecular structures, including existing FDA-approved drugs. For example, cures for cancers may depend on designing antigens specific to particular cancer cells. Starting with already approved drugs may shorten the discovery process because their toxicology may already have been studied. Many of the AI components shown below can be replaced with GNN technologies for modeling the molecular structure.

Modeling the Spread of COVID

Google uses aggregated GPS data from mobile devices to create temporal and spatial edges between nodes (places) and model the mobility of people during the pandemic.

COVID report data and Google mobility data can be combined to model the spread of the virus. For example, a GCN builds a latent representation of each node and makes node-level predictions, such as COVID case counts.
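
To make the node-level prediction concrete, here is a minimal sketch of one standard GCN layer (symmetric normalization with self-loops) followed by a linear readout per node. The tiny graph, the random features and the weight names are assumptions for illustration, not the model used in the work above.

```python
import numpy as np

def gcn_layer(A, H, W):
    """One GCN layer: ReLU(D^-1/2 (A + I) D^-1/2 H W)."""
    A_hat = A + np.eye(A.shape[0])             # add self-loops
    deg = A_hat.sum(axis=1)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt   # symmetric normalization
    return np.maximum(A_norm @ H @ W, 0.0)

# Three places connected in a line: 0 - 1 - 2.
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))       # per-node input features (hypothetical)
W1 = rng.normal(size=(4, 8))      # layer weights (untrained, random)
w_out = rng.normal(size=(8, 1))   # readout weights (untrained, random)

H1 = gcn_layer(A, X, W1)          # latent representation of each node
pred = H1 @ w_out                 # one scalar prediction per node
```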

Social influence prediction

Social influence prediction focuses on the impact of friends’ actions, particularly within social networks. For example, if some of user v’s friends on a social network bought a jacket, will v buy it as well? With a social graph as input, DeepInf learns a network embedding (a latent social representation) for a user. Combined with the handcrafted features in (d) below, it predicts social influence, such as whether v will also view an advertisement clip (step f). During training, it compares its predictions with the ground truth to learn this network embedding.

Recommender systems

Objects can be visually similar but are in fact totally different objects. For example, the top row below contains objects that are quite different from the intended image query on the left even though they are visually similar.

On Pinterest, boards connect pins together to form graphs. PinSage is a random-walk GCN that learns embeddings for nodes (images) in Pinterest graphs. Since the graphs contain billions of objects, performing convolutions on such a huge graph is not efficient. Instead, Pinterest constructs the graphs dynamically. It simulates random walks and samples neighbors with weights based on their visit counts. This process constructs dynamic and much smaller graphs. Convolutions are then applied to compute the node embeddings.
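
The random-walk sampling idea can be sketched as below: run short walks from a node, count how often other nodes are visited, and keep the most visited ones as a small weighted neighborhood. The function name, the toy graph and the parameter values are hypothetical, and this is not Pinterest's production implementation.

```python
import random
from collections import Counter

def sample_neighborhood(adj, start, num_walks, walk_len, top_k, rng):
    """Pick the top_k most visited nodes over short random walks from start."""
    visits = Counter()
    for _ in range(num_walks):
        node = start
        for _ in range(walk_len):
            node = rng.choice(adj[node])   # step to a uniformly random neighbor
            if node != start:
                visits[node] += 1
    top = visits.most_common(top_k)
    total = sum(count for _, count in top)
    if total == 0:
        return []
    # Normalized visit counts act as importance weights.
    return [(node, count / total) for node, count in top]

adj = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1], 3: [1]}
neighbors = sample_neighborhood(adj, start=0, num_walks=50, walk_len=3,
                                top_k=2, rng=random.Random(0))
```

A convolution for node 0 would then aggregate only over `neighbors`, using the weights, instead of over the full graph.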

Uber Eats uses GraphSAGE to make recommendations.

Traffic forecasting

DCRNN incorporates both the spatial dependency (on the road network) and the temporal dependency (changing road conditions) of traffic flow for traffic forecasting. Sensors on the roads are modeled as nodes in a graph. DCRNN captures the spatial dependency using bidirectional random walks on the graph, and the temporal dependency using an encoder-decoder architecture.
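
The spatial part can be sketched as a diffusion convolution, which mixes a graph signal along forward and backward random-walk transition matrices for a few steps. This is a minimal NumPy sketch under the assumption that every sensor has at least one incoming and one outgoing edge; the weights `theta_fwd` and `theta_bwd` would be learned in practice and are fixed by hand here.

```python
import numpy as np

def diffusion_conv(W_adj, X, theta_fwd, theta_bwd):
    """K-step bidirectional diffusion convolution of graph signal X."""
    P_fwd = W_adj / W_adj.sum(axis=1, keepdims=True)       # D_out^-1 W
    P_bwd = W_adj.T / W_adj.T.sum(axis=1, keepdims=True)   # D_in^-1 W^T
    out = np.zeros_like(X)
    X_fwd, X_bwd = X, X
    for k in range(len(theta_fwd)):
        # Add the k-step forward and backward diffusion terms.
        out = out + theta_fwd[k] * X_fwd + theta_bwd[k] * X_bwd
        X_fwd = P_fwd @ X_fwd
        X_bwd = P_bwd @ X_bwd
    return out

# Three sensors on a directed loop 0 -> 1 -> 2 -> 0 with edge weights.
W_adj = np.array([[0.0, 1.0, 0.0],
                  [0.0, 0.0, 2.0],
                  [3.0, 0.0, 0.0]])
X = np.array([[60.0], [45.0], [30.0]])   # e.g. one speed reading per sensor

out = diffusion_conv(W_adj, X, theta_fwd=[0.5, 0.3], theta_bwd=[0.1, 0.1])
```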

Google DeepMind also uses GNNs to estimate travel times and plan routes accordingly.

Scene graph generation

Given an image, we can generate a scene graph that describes the objects in it and their relationships.

The model below generates the scene graph using GRUs and learns to iteratively improve its predictions via message passing.

F-net uses a bottom-up clustering method to factorize the entire graph into subgraphs, where each subgraph contains several objects and a subset of their relationships. By using a divide-and-conquer approach, the computation in the intermediate stage is significantly reduced.

Conversely, we can generate an image from a scene graph.

For example, given an existing image, we can generate a river in a specific area of the image.

Link Prediction

Link prediction predicts whether two nodes in a network are likely to have a link. In recommender systems, we recommend products that are highly “connected”.

SEAL extracts a local enclosing subgraph around A and B below, with the link AB omitted. A GNN is then trained to predict whether such a link exists.
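
The subgraph extraction step can be sketched as follows: collect all nodes within a few hops of A or B, keep the edges among them, and drop the link AB itself. The helper names and the toy graph are hypothetical, and the sketch leaves out the rest of the SEAL pipeline.

```python
from collections import deque

def within_hops(adj, source, hops):
    """Return all nodes reachable from source in at most `hops` steps."""
    seen = {source}
    frontier = deque([(source, 0)])
    while frontier:
        node, dist = frontier.popleft()
        if dist == hops:
            continue
        for nbr in adj[node]:
            if nbr not in seen:
                seen.add(nbr)
                frontier.append((nbr, dist + 1))
    return seen

def enclosing_subgraph(adj, a, b, hops=1):
    """Induced subgraph around a and b, with the target link a-b omitted."""
    nodes = within_hops(adj, a, hops) | within_hops(adj, b, hops)
    edges = set()
    for u in nodes:
        for v in adj[u]:
            if v in nodes and {u, v} != {a, b}:
                edges.add((min(u, v), max(u, v)))   # undirected, stored once
    return nodes, edges

adj = {"A": {"B", "C"}, "B": {"A", "D"}, "C": {"A", "D"},
       "D": {"B", "C", "E"}, "E": {"D"}}
nodes, edges = enclosing_subgraph(adj, "A", "B", hops=1)
# nodes: A, B, C, D; edges: A-C, B-D, C-D (A-B is left out, E is too far away)
```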

Point Cloud Classification & Segmentation

3D scanners like LiDAR produce 3D point clouds, a representation of objects in 3D space with coordinates and possibly color information.

With GNN, we can segment the data points and classify them using the model below.

Here is an example of the raw 3D object data points and how they are segmented and classified to identify the objects.

And this is the model for generating the 3D segmentation for the example above.

Human-object interaction

GPNN explains a given scene with the graph structure. For example, it labels the link between the person and the knife with “lick”.

Text Classification

We can apply GNNs to topical text classification, with applications such as news classification, Q&A, and search result organization.

The model below slides a three-word window over the raw text to create a graph-of-words. This graph captures word co-occurrence within a three-word range. The model then selects nodes from the graph based on the rank of each node (its number of connections). For each selected node, it uses breadth-first search to find a small subgraph that contains this node and four more nodes. Each subgraph is ordered so that convolutions can be applied to all subgraphs consistently.
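
A minimal sketch of these steps: build the graph-of-words with a three-word window, rank nodes by their number of connections, and grow a subgraph by breadth-first search. The function names, the example sentence and the alphabetical tie-breaking are assumptions for illustration, not details of the original model.

```python
from collections import defaultdict, deque

def graph_of_words(tokens, window=3):
    """Connect each word to the words that follow it inside the window."""
    adj = defaultdict(set)
    for i, word in enumerate(tokens):
        for other in tokens[i + 1 : i + window]:
            if other != word:
                adj[word].add(other)
                adj[other].add(word)
    return adj

def rank_by_degree(adj):
    """Order words by number of connections, ties broken alphabetically."""
    return sorted(adj, key=lambda w: (-len(adj[w]), w))

def bfs_subgraph(adj, start, size):
    """Collect up to `size` nodes around start in breadth-first order."""
    order, queue = [start], deque([start])
    while queue and len(order) < size:
        node = queue.popleft()
        for nbr in sorted(adj[node]):
            if nbr not in order and len(order) < size:
                order.append(nbr)
                queue.append(nbr)
    return order

tokens = "graph neural networks learn graph structure".split()
adj = graph_of_words(tokens)
ranked = rank_by_degree(adj)
sub = bfs_subgraph(adj, ranked[0], size=5)   # top-ranked node plus four more
```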

The diagram below shows the architecture for making label predictions from these subgraphs.

Sequence labeling

Words in a sentence can be modeled as nodes in a graph and we can compute the hidden representation of each node and use it to label the sequence (labels for the words in the sequence).

Potential applications include POS-tagging, NER (Named-entity recognition), and Semantic Role Labeling (SRL). SRL assigns labels to words or phrases in a sentence that indicate their semantic role as shown below.

So, given a predicate (the word “disputed”), the model below identifies and labels all its arguments.

Relation Extraction in NLP

The diagram below represents various dependencies such as linear context (adjacent words), syntactic dependencies, and discourse relations.

This text suggests that tumors with the L858E mutation in the EGFR gene respond to the drug gefitinib. If we define a triple as (drug, gene, mutation), the text suggests that the triple (gefitinib, EGFR, L858E) has a “respond” relation.

In the architecture below, words in the sentences are encoded with word embeddings. Then it uses a graph LSTM to learn a contextual representation for each word. Next, we concatenate the contextual representation for the words (gefitinib, EGFR, L858E) together. Finally, we use a relation classifier to score (classify) the relations of these three words. So the relation “Respond” (say R₄) should have the highest score.
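
The last two steps, concatenation and relation scoring, can be sketched as follows. The contextual vectors are random stand-ins for the graph LSTM outputs, and the single linear layer with softmax is an assumed classifier used only for illustration.

```python
import numpy as np

def score_relations(entity_vectors, W, b):
    """Concatenate entity representations and return relation probabilities."""
    features = np.concatenate(entity_vectors)
    logits = W @ features + b
    probs = np.exp(logits - logits.max())   # numerically stable softmax
    return probs / probs.sum()

rng = np.random.default_rng(0)
dim, num_relations = 6, 5

# Stand-ins for the contextual representations of gefitinib, EGFR and L858E.
h_drug, h_gene, h_mutation = (rng.normal(size=dim) for _ in range(3))

W = rng.normal(size=(num_relations, 3 * dim))   # untrained classifier weights
b = np.zeros(num_relations)

probs = score_relations([h_drug, h_gene, h_mutation], W, b)
predicted = int(np.argmax(probs))   # index of the highest-scoring relation
```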

Pose estimation

ST-GCN performs convolutions over each node's spatial and temporal neighbors to estimate the pose in the input video.

Chip design

In chip design, the placement and routing of standard cells impact the power, die size, and performance of a chip. Google demonstrated the use of GNN and Reinforcement Learning to optimize cell placement.

The chip netlist graph (node types and graph adjacency information) and the current node to be placed are passed through a GNN to encode the input state.

These embeddings are concatenated with the metadata embedding (like the total number of wires) and fed into a neural network. The output is a learned latent representation that serves as input to the policy and value networks in Reinforcement Learning. The policy network produces a probability distribution over all possible cell placements for the current node. We can then randomly sample an action from this probability distribution.
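
The final sampling step can be sketched as below, assuming the policy network outputs one logit per cell of a placement grid. The grid size, the random logits and the function name are hypothetical.

```python
import numpy as np

def sample_placement(logits, rng):
    """Sample a grid location from the policy's distribution over placements."""
    flat = logits.ravel()
    probs = np.exp(flat - flat.max())   # softmax over all grid cells
    probs /= probs.sum()
    idx = rng.choice(flat.size, p=probs)
    row, col = np.unravel_index(idx, logits.shape)
    return (int(row), int(col)), probs.reshape(logits.shape)

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 4))   # stand-in for the policy network output
location, probs = sample_placement(logits, rng)
```

Sampling, rather than always taking the most probable cell, lets the agent explore alternative placements during training.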

Particle Physics

At Fermilab, researchers use GNNs to analyze the images produced by the CMS detector at the Large Hadron Collider and identify interesting particles for particle physics experiments.


Many GNN applications are classified as node classification, graph classification, network embedding, node clustering, link prediction, graph generation, spatial-temporal graph forecasting, and graph partitioning. Here is a more detailed list of possible applications:
