Using Attention to Improve Neural Image Caption Generator

18 Jan, 2020

Introduction

In this post, I attempt to explain attention in neural networks in the context of image captioning, including how it improves model performance and provides some interpretability. We will manually compare the captions generated by the vanilla model against those of the attention-powered one. We will also glance through the automatic evaluation criteria used for comparing models in the NLP world and review the pros and cons of using such evaluation metrics.

Recall that, in the previous post, we used the vanilla seq2seq architecture to create a model that could caption images. The seq2seq architecture has two parts: an encoder and a decoder. We use the encoder to convert the input image into a low-dimensional image vector. The decoder then uses the image vector to generate the caption, producing a single word at every timestep from a combination of the word generated at the previous step and the image vector. It continues until it reaches the end-of-sentence marker or a predefined maximum number of timesteps.
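The decoding loop described above can be sketched roughly as follows. This is a minimal illustration, not the model from the previous post: `greedy_decode`, `step_fn` and `toy_step` are hypothetical names, the toy decoder is a stand-in for a trained network, and picking the highest-scoring word at each step (greedy decoding) is an assumption made here for simplicity.

```python
import numpy as np


def greedy_decode(step_fn, image_vector, start_id, end_id, max_steps):
    """Generate word ids one at a time until end_id or max_steps is reached.

    step_fn is a hypothetical stand-in for the trained decoder. It takes the
    previous word id, the image vector and the hidden state, and returns
    (scores over the vocabulary, new hidden state).
    """
    tokens = []
    prev_id, hidden = start_id, None
    for _ in range(max_steps):
        scores, hidden = step_fn(prev_id, image_vector, hidden)
        prev_id = int(np.argmax(scores))  # assumption: take the top-scoring word
        if prev_id == end_id:             # stop at the end-of-sentence marker
            break
        tokens.append(prev_id)
    return tokens


def toy_step(prev_id, image_vector, hidden):
    """Toy decoder: always scores the id after prev_id highest (0 -> 1 -> 2 ...)."""
    vocab_size = 5
    scores = np.zeros(vocab_size)
    scores[(prev_id + 1) % vocab_size] = 1.0
    return scores, hidden


caption_ids = greedy_decode(toy_step, np.ones(8), start_id=0, end_id=4, max_steps=10)
capped_ids = greedy_decode(toy_step, np.ones(8), start_id=0, end_id=99, max_steps=3)
```

Note that the same `image_vector` is passed in unchanged at every step, which is exactly the behaviour attention will later replace.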

We will make many references to neural machine translation (NMT) because innovations such as seq2seq and attention were introduced there. These were later adopted by other domains in NLP as well as computer vision.

Motivation

In NMT, we pass the source sentence through the encoder one word at a time. At each timestep, the encoder updates its hidden state, and we expect the final hidden state $h_N^{encoder}$ to encapsulate enough information to allow the decoder to generate the translation. The decoder makes use of $h_N^{encoder}$, along with its own internal hidden state $h_t^{decoder}$, to generate one word at a time. We seem to be asking a lot from the final hidden state of the encoder, and indeed, the layer highlighted in red is the information bottleneck.

We visualise the NMT architecture below with input in the Sanskrit language and its English translation.

The input source is taken from the Bhagavad Gita, the ancient Sanskrit scripture, and it talks about focusing on the process rather than the results.

Example of seq2seq NMT model

If you and I were to caption an image, we would most likely look at specific parts of the image as we come up with the caption. In contrast, our model looked at the entire image (vector) at every timestep. What if we could teach the network to focus on certain parts instead? Similarly, in NMT, what if the decoder could access all the hidden states of the encoder and somehow learn how much it should focus on each of $h_1^{encoder}, h_2^{encoder}, \ldots, h_N^{encoder}$ to generate the next word in the target language? This motivates the concept of attention in neural networks.

Attention

Attention allows the neural network (the decoder, in the case of image captioning) to focus on specific parts of the image as it generates the caption. Before we see how it is done, let's visualise attention using the cartoon below.


We overlay the attention heatmap to visualise what parts are in focus at each timestep. As you can see, the decoder now focuses on certain parts of the image as it decides the next word. At this point, I must emphasize that without attention, the decoder in the previous post used the entire image vector at every timestep.

The hidden state of the decoder, $h_N^{decoder}$, is called the query, and the hidden state of the encoder, $h_N^{encoder}$, is called the value. In our case, the value is simply the image vector (i.e., the output of the CNN-based encoder). We calculate attention using the hidden state of the decoder at a particular timestep and the entire image vector from the CNN-based encoder. With that, it's time for the definition:

Attention is the weighted sum of values dependent on the query

In the case of NMT seq2seq, both the encoder and the decoder are usually some variant of RNN and hence have internal hidden states. The decoder makes use of the encoder hidden states from all steps ($h_1^{encoder}, h_2^{encoder}, \ldots, h_N^{encoder}$) to calculate the attention score. This means that the decoder is no longer restricted by the limitations of relying on the final hidden state $h_N^{encoder}$ alone. Thus, attention provides a solution to the information bottleneck problem we saw earlier.

General Framework

There are several variants of attention, but the process of generating attention generally follows these three steps:

  1. Attention Score: Calculate the attention score, $e^t \in \mathbb{R}^N$, using the hidden states of the encoder, $h_i^{encoder} \in \mathbb{R}^h$, and the hidden state of the decoder, $s_t \in \mathbb{R}^h$.
  2. Attention Distribution: Calculate the attention distribution by applying softmax to the scores of all encoder hidden states: $\alpha^t = \mathrm{softmax}(e^t) \in \mathbb{R}^N$.
  3. Attention Output: Calculate the attention output, also known as the context vector, as the sum of the encoder hidden states weighted by the attention distribution: $a_t = \sum_{i=1}^{N} \alpha_i^t h_i^{encoder} \in \mathbb{R}^h$.
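The three steps above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the framework does not fix how the score is computed, so a plain dot product between each encoder hidden state and the decoder hidden state is assumed here, and the sizes and random inputs are toy values rather than anything from a trained model.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    shifted = x - np.max(x)
    exp_x = np.exp(shifted)
    return exp_x / exp_x.sum()


def attention(encoder_states, s_t):
    """Return (scores, distribution, output) for one decoder timestep.

    encoder_states has shape (N, h); s_t has shape (h,).
    """
    e_t = encoder_states @ s_t       # step 1: one score per encoder state, shape (N,)
    alpha_t = softmax(e_t)           # step 2: distribution over the N states
    a_t = alpha_t @ encoder_states   # step 3: weighted sum of states, shape (h,)
    return e_t, alpha_t, a_t


# Toy inputs: N = 4 encoder states of size h = 3 (illustrative values only).
rng = np.random.default_rng(0)
N, h = 4, 3
encoder_states = rng.normal(size=(N, h))
s_t = rng.normal(size=h)

e_t, alpha_t, a_t = attention(encoder_states, s_t)
```

The attention distribution `alpha_t` sums to one, so `a_t` is a convex combination of the encoder hidden states and lives in the same space as each of them.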

We then concatenate the attention output $a_t$ and the decoder hidden state $s_t$ to obtain $[a_t; s_t] \in \mathbb{R}^{2h}$, and continue with the rest of the forward pass depending on the architecture (in this case, we kept the architecture the same as the vanilla seq2seq model: GRU -> fully connected).
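As a rough sketch of the concatenation step, the snippet below joins an attention output and a decoder hidden state and passes the result through a single fully connected layer. It is a simplification: the GRU step is skipped for brevity, and the weights `W`, bias `b`, and the sizes `h` and `vocab_size` are hypothetical random or toy values, not parameters of the actual model.

```python
import numpy as np

rng = np.random.default_rng(0)
h, vocab_size = 3, 5                 # toy sizes, chosen only for illustration

a_t = rng.normal(size=h)             # attention output (context vector)
s_t = rng.normal(size=h)             # decoder hidden state

# Concatenate along the feature axis: the result has 2h entries.
combined = np.concatenate([a_t, s_t])

# Hypothetical fully connected layer mapping the 2h features to vocabulary scores.
W = rng.normal(size=(vocab_size, 2 * h))
b = np.zeros(vocab_size)
logits = W @ combined + b

next_word_id = int(np.argmax(logits))
```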

These steps are discussed in the context of NMT but are also applicable to the image captioning model. Instead of the encoder hidden states, we just make use of the image vector. If it helps, imagine setting $h^{encoder} = \text{image vector}$ and re-reading the steps above.
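To make the image case concrete, here is a minimal sketch under one loudly stated assumption: the CNN encoder output is kept as a grid of per-location feature vectors (here a hypothetical 8 x 8 grid, each already of size `h`) rather than a single pooled vector, so that there is something to distribute attention over. The dot-product score, the name `attend_to_image` and the random features are likewise illustrative assumptions.

```python
import numpy as np


def attend_to_image(image_features, s_t):
    """Attend over image locations given the decoder hidden state.

    image_features has shape (num_locations, h); s_t has shape (h,).
    """
    scores = image_features @ s_t                  # one score per image location
    scores = scores - scores.max()                 # stabilise the softmax
    weights = np.exp(scores) / np.exp(scores).sum()
    context = weights @ image_features             # weighted sum of location features
    return weights, context


rng = np.random.default_rng(1)
grid, h = 8, 16                                    # hypothetical grid and feature sizes
image_features = rng.normal(size=(grid * grid, h))
s_t = rng.normal(size=h)

weights, context = attend_to_image(image_features, s_t)

# Reshaping the weights back to the grid gives a heatmap that can be
# upsampled and overlaid on the image.
heatmap = weights.reshape(grid, grid)
```

A heatmap of this kind, computed once per generated word, is what an attention overlay on the image visualises.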

Types of Attention

There are several ways to compute the attention score. In this post, we cover three common ones listed below. If you're interested in learning more, I recommend this detailed post on attention mechanisms (e.g., Hard vs Soft attention, vs attention).

Dot Product Attention

This is the most basic form of attention and the fastest to compute. We calculate the attention score using the decoder hidden state $s \in \mathbb{R}^h$ and the hidden state