Neural Turing Machines: A Fundamental Approach to Accessing Memory in Deep Learning
Memory is a crucial part of both the brain and the computer. For example, in question answering, we memorize information that we have processed and use it to answer questions.
From the Neural Turing Machine (NTM) paper:
We extend the capabilities of neural networks by coupling them to external memory resources, which they can interact with by attentional processes.
This memory resource is realized as an array structure that we read from and write to. Sounds simple? Not exactly. First, we do not have unlimited memory capacity to hold every image or voice we have encountered, and we want information to be retrievable by similarity and relevance. In this article, we discuss how NTM accesses information. We are interested in this paper because it is an important starting point for many research areas, including NLP and meta-learning.
The deep learning (DL) memory structure Mt contains N rows (N objects), each with M elements. Each row encodes a piece of information, for example, some representation of the latent factors of your cousin’s voice in M dimensions.
In conventional programming, we access memory by index, as in Mt[i]. But for AI, information should also be retrievable by similarity (by content). So NTM derives a reading mechanism that uses weights measuring the similarity between the input and each memory row. The recalled memory output rt is a weighted sum of these memory rows:
$$r_t = \sum_{i=1}^{N} w_t(i)\, M_t(i)$$
where wᵢ (short for w_t(i)) is the weight, or amount of attention, we give to memory row i. All wᵢ are computed with a softmax, so they are non-negative and add up to one.
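The read operation is simple enough to sketch directly. The following is a minimal illustration in plain Python, assuming the memory is stored as a list of N rows (each a list of M floats) and the weights are already given. The function name `read_memory` and the example values are hypothetical.

```python
def read_memory(memory, weights):
    """Return r = sum_i w[i] * memory[i], a weighted sum of memory rows."""
    # Assumes memory is N rows of M floats and weights holds N values summing to one.
    if len(memory) != len(weights):
        raise ValueError("need one weight per memory row")
    m = len(memory[0])
    result = [0.0] * m
    for w_i, row in zip(weights, memory):
        for k in range(m):
            result[k] += w_i * row[k]
    return result


# Hypothetical 3 x 4 memory: N = 3 rows, M = 4 elements per row.
memory = [
    [1.0, 0.0, 0.0, 2.0],
    [0.0, 1.0, 0.0, 2.0],
    [0.0, 0.0, 1.0, 2.0],
]
weights = [0.5, 0.5, 0.0]  # split the attention between rows 0 and 1
print(read_memory(memory, weights))  # [0.5, 0.5, 0.0, 2.0]
```

Splitting the weight evenly between two rows returns a blend of both, which is exactly the "tea plus milk" style of interpolation described next.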
You may immediately ask what purpose this serves. Let’s go through an example. A friend hands you a drink. It tastes like tea and feels like milk. By extracting our memory profiles of tea and milk, we apply linear algebra to interpolate the final result and find out that it is boba tea. Sounds like magic. But in word embeddings, we use the same kind of linear algebra to manipulate relationships.
So how do we create those weights? A controller extracts features kt from the input using a deep network (an LSTM or a feedforward network), and we use them to compute the weights. For example, you take a phone call but cannot recognize the voice immediately. The voice sounds a whole lot like your cousin’s, but it also resembles your elder brother’s. The recalled memory output will be a sum of your cousin’s and your brother’s voices, weighted by similarity.
Mathematically, to compute the weight wᵢ, we measure the similarity between kt and each memory entry. We calculate a score K using cosine similarity:
$$K(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}$$
Here, u is our extracted feature kt, and v is each individual row in our memory.
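As a quick sketch of the similarity score, here is cosine similarity in plain Python. The small `eps` term is an assumption added to avoid dividing by zero for an all-zero vector. It is not part of the formula in the text.

```python
import math


def cosine_similarity(u, v, eps=1e-8):
    """K(u, v) = (u . v) / (||u|| * ||v||), a score between -1 and 1."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # eps is an assumed guard against dividing by zero for an all-zero vector.
    return dot / (norm_u * norm_v + eps)


print(round(cosine_similarity([1.0, 2.0], [2.0, 4.0]), 3))  # 1.0: same direction
print(round(cosine_similarity([1.0, 0.0], [0.0, 1.0]), 3))  # 0.0: orthogonal
print(round(cosine_similarity([1.0, 1.0], [1.0, 0.0]), 3))  # 0.707
```

The first example shows why cosine similarity suits content lookup: [1, 2] and [2, 4] score 1.0 despite different magnitudes, so only the direction of the feature vector matters.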
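We apply a softmax function to the scores K to compute the weights wᵢ:
$$w_t(i) = \frac{\exp\big(\beta_t\, K(k_t, M_t(i))\big)}{\sum_{j} \exp\big(\beta_t\, K(k_t, M_t(j))\big)}$$
βt multiplies every score before the softmax to amplify or attenuate the differences between scores. For example, if βt is greater than one, it amplifies the differences, and the weights concentrate on the best-matching rows. Because w retrieves information based on similarity, we call this content-based addressing.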
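Putting the cosine score and the βt-scaled softmax together gives content-based addressing. This is a sketch under simple assumptions: lists instead of tensors, a hypothetical 3-row memory, and the usual max-subtraction trick for numerical stability, which does not change the softmax result.

```python
import math


def cosine_similarity(u, v, eps=1e-8):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + eps)


def content_weights(key, memory, beta):
    """w(i) = exp(beta * K(key, M(i))) / sum_j exp(beta * K(key, M(j)))."""
    scores = [beta * cosine_similarity(key, row) for row in memory]
    top = max(scores)  # subtract the max before exp for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


memory = [[1.0, 0.0], [0.7, 0.7], [0.0, 1.0]]  # hypothetical memory rows
key = [1.0, 0.2]                               # hypothetical controller key k_t
for beta in (1.0, 10.0):
    print(beta, [round(w, 3) for w in content_weights(key, memory, beta)])
```

With β = 1 the weight is spread over all three rows. With β = 10 most of it lands on row 0, the closest match. With β = 0 every row gets the same weight, whatever the key.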
So how do we write information into memory? In LSTM, the internal state of a cell is a combination of the previous state and a new input state:
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
The LSTM learns to compute a forget gate f that controls which parts of the previous state are forgotten (erased), and an input gate i that controls which parts of the new candidate state c̃ are added to the current cell.
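To make the analogy concrete, here is the LSTM cell-state update on its own, written element by element. In a real LSTM the gate values come from sigmoid layers applied to the input and the previous hidden state. Here they are hand-picked, hypothetical numbers chosen to show "keep", "forget" and "half-keep".

```python
def lstm_cell_state(c_prev, forget_gate, input_gate, candidate):
    """c_t = f * c_{t-1} + i * c~_t, applied element by element."""
    return [f * c + i * g for c, f, i, g in zip(c_prev, forget_gate, input_gate, candidate)]


c_prev = [2.0, -1.0, 0.5]
f = [1.0, 0.0, 0.5]  # keep, forget, half-keep the previous state
i = [0.0, 1.0, 0.5]  # ignore, take, half-take the new candidate
candidate = [9.0, 3.0, 1.0]
print(lstm_cell_state(c_prev, f, i, candidate))  # [2.0, 3.0, 0.75]
```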
Borrowing the same intuition, the memory writing process is composed of the previous state and a new input. First, we erase part of the previous state:
$$\tilde{M}_t(i) = M_{t-1}(i) \odot \big(\mathbf{1} - w_t(i)\, e_t\big)$$
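where et is an erase vector whose elements lie between 0 and 1, and 𝟏 is a vector of ones. It acts like the forget gate in LSTM: row i is erased only to the extent that it is attended to (wt(i)), and only in the elements selected by et.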
Then, we write our new information:
$$M_t(i) = \tilde{M}_t(i) + w_t(i)\, a_t$$
where at is the add vector, the content we want to add. It plays the role that the input gate, applied to the new candidate state, plays in LSTM.
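The two write steps, erase then add, can be sketched in a few lines. This assumes the weights, erase vector and add vector are already computed. The name `write_memory` and the toy values are hypothetical.

```python
def write_memory(memory, weights, erase, add):
    """Erase then add, with each row's update scaled by its weight w(i).

    erase step: M~(i) = M_prev(i) * (1 - w(i) * e)   (element-wise)
    add step:   M(i)  = M~(i) + w(i) * a
    """
    new_memory = []
    for w_i, row in zip(weights, memory):
        erased = [m * (1.0 - w_i * e) for m, e in zip(row, erase)]
        new_memory.append([m + w_i * a for m, a in zip(erased, add)])
    return new_memory


memory = [[1.0, 1.0], [1.0, 1.0]]
weights = [1.0, 0.0]  # focus entirely on row 0
erase = [1.0, 0.0]    # fully erase element 0, keep element 1
add = [5.0, 2.0]
print(write_memory(memory, weights, erase, add))  # [[5.0, 3.0], [1.0, 1.0]]
```

Row 1 has zero weight, so it is left untouched. Row 0 loses element 0 to the erase step and then receives the add vector.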
In DL problems, et and at are not fixed parameters. They are outputs of a DNN, say an MLP layer that takes the hidden state ht of the LSTM controller as input, and it is that network’s weights that are trained.
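One way such a controller "write head" could look is sketched below. It uses a linear layer with a sigmoid to keep every element of et in (0, 1), and a plain linear layer for at. The layer shapes, the random initialization and all names (`init_write_head`, `write_head`, `W_e`, `W_a`) are hypothetical. In a trained NTM these weights are learned by backpropagation.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def linear(weight, bias, x):
    """y = W x + b, with W stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def init_write_head(hidden_size, m, seed=0):
    """Hypothetical random initialization; a real NTM learns these weights."""
    rng = random.Random(seed)

    def matrix():
        return [[rng.gauss(0.0, 0.1) for _ in range(hidden_size)] for _ in range(m)]

    return {"W_e": matrix(), "b_e": [0.0] * m, "W_a": matrix(), "b_a": [0.0] * m}


def write_head(h, params):
    """Map the controller hidden state h to an erase vector e and an add vector a."""
    e = [sigmoid(z) for z in linear(params["W_e"], params["b_e"], h)]  # each e_k in (0, 1)
    a = linear(params["W_a"], params["b_a"], h)                         # left unconstrained
    return e, a


params = init_write_head(hidden_size=8, m=4)
h = [0.1 * k for k in range(8)]  # stand-in for the LSTM hidden state h_t
e, a = write_head(h, params)
print([round(x, 3) for x in e])
print([round(x, 3) for x in a])
```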
In this way, a controller generates the weights w, and we use them both to read from and to write to our memory.
w acts as the addressing mechanism through which NTM accesses memory. So far, we have used content-based addressing to compute w. But in NTM, the addressing mechanism can be further enhanced with location-based addressing. For example, location-based addressing allows NTM to implement variable-based computation like a=b+c, where we need to access a specific location in memory. (We will not elaborate on this example, since it is outside our scope.) Let’s look at the different addressing mechanisms that can be added to NTM.
w represents our current focus (attention) in memory. In content-based addressing, our focus is based only on the current input. However, this does not account for what we attended to previously. For example, if your classmate texted you an hour ago, you should be more likely to recall him. How do we take previous attention into account when extracting information? We compute a new merged weighting from the content-based focus (w^c_t) and our focus in the last timestep. Yes, this sounds like the gating mechanism in LSTM or GRU:
$$w^g_t = g_t\, w^c_t + (1 - g_t)\, w_{t-1}$$
where the scalar gate gt, between 0 and 1, controls how much of the content-based weighting and how much of the weighting w from the last time step is kept.
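Here is the interpolation gate as a small sketch, assuming both weightings are plain lists of the same length. The example rows and gate values are hypothetical.

```python
def interpolate(w_content, w_prev, g):
    """w^g = g * w^c + (1 - g) * w_{t-1}, with the gate g in [0, 1]."""
    if not 0.0 <= g <= 1.0:
        raise ValueError("interpolation gate g must lie in [0, 1]")
    return [g * wc + (1.0 - g) * wp for wc, wp in zip(w_content, w_prev)]


w_prev = [0.0, 1.0, 0.0]     # last step we attended to row 1
w_content = [1.0, 0.0, 0.0]  # the new input points to row 0
print(interpolate(w_content, w_prev, 1.0))   # [1.0, 0.0, 0.0]: pure content addressing
print(interpolate(w_content, w_prev, 0.0))   # [0.0, 1.0, 0.0]: keep last step's focus
print(interpolate(w_content, w_prev, 0.25))  # [0.25, 0.75, 0.0]
```

Because both inputs sum to one and the mix is convex, the result also sums to one, so it is still a valid attention weighting.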
Convolution shift handles a shift of focus. For example, we can shift the focus by 3 rows, i.e. w[i] ← w[i+3] (the weight for row i+3 becomes the weight for row i).
In general, convolution shift creates a focus from a range of rows, i.e. w[i] ← convolution(w[i+3], w[i+4], w[i+5]), where the convolution function is a linear weighted sum of those rows’ weights, for example 0.3 × w[i+3] + 0.5 × w[i+4] + 0.2 × w[i+5].
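This is the mathematical formulation for shifting our focus, where st(k) is the weight given to a shift of k rows and all indices are computed modulo N:
$$\tilde{w}_t(i) = \sum_{j=0}^{N-1} w^g_t(j)\, s_t(i - j)$$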
In engineering, this is simply the definition of (circular) convolution. This mechanism lets an NTM perform basic algorithms like copying and sorting. In many deep learning models, we can skip this step by setting s(i) to 0 for every i except i=0, where s(0) = 1, which leaves the weights unchanged.
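A minimal sketch of the circular convolution shift follows. It assumes the shift weighting s is a length-N list where s[k] is the weight on moving the focus forward by k rows. Because indices wrap modulo N, the article's w[i] ← w[i+3] corresponds to putting all of s on the offset −3, which is 2 when N = 5.

```python
def shift(w, s):
    """Circular convolution: w~(i) = sum_j w(j) * s((i - j) mod N).

    s[k] is the weight on moving the focus forward by k rows; because indices
    wrap modulo N, s[N - k] moves it back by k rows.
    """
    n = len(w)
    if len(s) != n:
        raise ValueError("shift weighting must have one entry per row")
    return [sum(w[j] * s[(i - j) % n] for j in range(n)) for i in range(n)]


w = [0.0, 0.0, 0.0, 1.0, 0.0]  # focus on row 3 (N = 5)

# w[i] <- w[i + 3]: move back by 3 rows, i.e. forward by 2 modulo 5.
print(shift(w, [0.0, 0.0, 1.0, 0.0, 0.0]))  # [1.0, 0.0, 0.0, 0.0, 0.0]

# s(0) = 1 and 0 elsewhere leaves the weighting unchanged.
print(shift(w, [1.0, 0.0, 0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0, 1.0, 0.0]

# Splitting s over two offsets blurs the focus across neighboring rows.
print(shift(w, [0.2, 0.8, 0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0, 0.2, 0.8]
```

The last case shows the blurring effect: a single sharp focus is spread over two rows, which is what the sharpening step is meant to undo.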
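Our convolution shift behaves like a convolutional blurring filter: it can spread the focus over several rows. So we can apply a sharpening technique to our weights to counteract the blurring if needed. γt ≥ 1 is another parameter output by the controller to sharpen our focus:
$$w_t(i) = \frac{\tilde{w}_t(i)^{\gamma_t}}{\sum_{j} \tilde{w}_t(j)^{\gamma_t}}$$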
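Sharpening raises each weight to the power γt and renormalizes. Here is a sketch with a hypothetical blurred weighting.

```python
def sharpen(w, gamma):
    """w(i) = w~(i)**gamma / sum_j w~(j)**gamma; the NTM paper requires gamma >= 1."""
    if gamma < 1.0:
        raise ValueError("gamma must be at least 1")
    powered = [x ** gamma for x in w]
    total = sum(powered)
    return [p / total for p in powered]


blurred = [0.1, 0.6, 0.3]  # e.g. a weighting after a blurring shift
print([round(x, 3) for x in sharpen(blurred, 1.0)])  # [0.1, 0.6, 0.3]
print([round(x, 3) for x in sharpen(blurred, 3.0)])  # [0.004, 0.885, 0.111]
```

With γ = 1 nothing changes. With γ = 3 the largest weight grows from 0.6 to about 0.885, because larger values gain relatively more from the power.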
Putting it together
All these addressing mechanisms can be combined to form an attention focus over which rows to access. Here is the generic system diagram, in which a controller outputs the necessary parameters and integrates the different addressing mechanisms at different stages. The final attention focus used is wt.
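The whole pipeline can be sketched as one function: content addressing, then interpolation, then shift, then sharpening, followed by a read. This is a simplified, single-step illustration in plain Python. The controller outputs (key, β, g, s, γ) and the memory contents are hand-picked, hypothetical values rather than outputs of a trained controller.

```python
import math


def cosine_similarity(u, v, eps=1e-8):
    dot = sum(a * b for a, b in zip(u, v))
    norms = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norms + eps)


def softmax(xs):
    top = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def address(memory, w_prev, key, beta, g, s, gamma):
    """One addressing step: content -> interpolation -> shift -> sharpening."""
    n = len(memory)
    # 1. Content addressing: softmax over beta-scaled cosine similarities.
    w_c = softmax([beta * cosine_similarity(key, row) for row in memory])
    # 2. Interpolation with the previous step's weighting, gate g in [0, 1].
    w_g = [g * c + (1.0 - g) * p for c, p in zip(w_c, w_prev)]
    # 3. Circular convolution shift, indices taken modulo N.
    w_s = [sum(w_g[j] * s[(i - j) % n] for j in range(n)) for i in range(n)]
    # 4. Sharpening with gamma >= 1, then renormalize.
    powered = [x ** gamma for x in w_s]
    total = sum(powered)
    return [p / total for p in powered]


def read_memory(memory, w):
    """r = sum_i w(i) * M(i)."""
    return [sum(w_i * row[k] for w_i, row in zip(w, memory)) for k in range(len(memory[0]))]


memory = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.5, 0.0]]
w_prev = [0.25, 0.25, 0.25, 0.25]  # hypothetical previous weighting
key = [0.0, 1.0, 0.0]              # the key best matches row 1
s = [0.0, 1.0, 0.0, 0.0]           # move the focus forward by one row
w = address(memory, w_prev, key, beta=20.0, g=0.9, s=s, gamma=2.0)
print([round(x, 3) for x in w])
print([round(x, 3) for x in read_memory(memory, w)])
```

With these settings, content addressing picks row 1, and the shift then moves the focus to row 2, so the read returns approximately row 2. This illustrates how location-based steps can move attention away from the row that merely matches the key.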