Implementation of Generative Adversarial Networks

With deep learning, more data has a direct impact on model accuracy [1][2]. Data collection and transformation, however, is a challenging task, so wouldn't it be nice to be able to generate an endless stream of synthetic data from an initial dataset? Imagine being able to learn the probability distribution of your dataset and to sample from it to train your network. That was my main motivation to learn more about Generative Adversarial Networks, and as I read more about GANs I found equally interesting ideas, covered in the use cases section below.

Generative adversarial networks (GANs), introduced by Goodfellow et al., provide a framework to synthetically generate training data and more, provided you have an initial dataset that represents the problem domain and you are able to train your GAN successfully (see Challenges below).

GANs make use of two differentiable networks, a Generator (G) and a Discriminator (D). G's goal is to learn the probability distribution of the target dataset, while D's goal is to successfully differentiate between real data from the target dataset and fake data generated by G. G takes random noise as input and produces output in a format understood by D. D is trained on both real and fake data, learning to assign a high probability to real data and a low probability to fake data. G, in turn, is trained to maximize the probability of D classifying its fake data as real. As training continues, G learns to generate realistic-looking data, and D can no longer differentiate between real and generated data, assigning a probability near 0.5 to its inputs.

In other words, D and G play a two-player minimax game:

$$\min\limits_G \max\limits_D V(D,G) = \Bbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \Bbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
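In practice, both terms of the value function are implemented with binary cross-entropy. Here is a minimal NumPy sketch of that mapping; the batch values `d_real` and `d_fake` are made-up numbers for illustration, where real code would get them from D:

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    # Binary cross-entropy: the negative of the log terms in V(D, G).
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Hypothetical batch outputs: d_real = D(x), d_fake = D(G(z)).
d_real = np.array([0.9, 0.8, 0.95])
d_fake = np.array([0.1, 0.2, 0.05])

# D maximizes log D(x) + log(1 - D(G(z))), i.e., minimizes
# BCE with label 1 for real data and label 0 for fakes.
d_loss = bce(np.ones_like(d_real), d_real) + bce(np.zeros_like(d_fake), d_fake)

# G is trained to make D assign label 1 to its fakes, i.e., it
# minimizes BCE with label 1 on d_fake (the non-saturating variant).
g_loss = bce(np.ones_like(d_fake), d_fake)
```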

MNIST GAN

We use the MNIST dataset of handwritten digits to train our model. The learned model is able to generate realistic-looking images after 15k iterations.

[Figure: images from the MNIST dataset (original) alongside images produced by the Generator (generated)]

If you wish to jump straight to the code, here is the Jupyter notebook.

The short clip below shows the output of the Generator as it learns to fool the Discriminator. The clip starts with the Generator producing noise and, over time, learning to generate digits.


Over the course of training, the mean and standard deviation of the generated data move closer to those of the original data.

[Plots: mean and standard deviation of generated vs. original data over training iterations]
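One simple way to produce such plots is to compare batch statistics every few iterations. A sketch, assuming a trained Keras model named `generator` (built as in the architecture section below) and the training images in `x_train`, both hypothetical names here:

```python
import numpy as np

# Sample a batch from the Generator and compare pixel statistics
# against the real data; the gap should shrink as training progresses.
noise = np.random.normal(0.0, 1.0, size=(256, 100))
fake = generator.predict(noise, verbose=0)
print("real mean/std:", x_train.mean(), x_train.std())
print("fake mean/std:", fake.mean(), fake.std())
```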

Network Architecture

We use fully connected layers in our simple network so that training can be done entirely on a CPU, running on a laptop. The images in our dataset are grayscale 28 x 28 pixels (i.e., small and single-channel), so a Generator with only fully connected layers has no problem learning the data distribution. For larger images, it is advised to make use of convolutional layers, a topic that we will explore in a future post.

| Generator    | Activation | Output  | Params    |
|--------------|------------|---------|-----------|
| Input        | -          | 100     | 0         |
| Dense        | LeakyReLU  | 256     | 25,856    |
| BatchNorm    | -          | 256     | 1,024     |
| Dense        | LeakyReLU  | 512     | 131,584   |
| BatchNorm    | -          | 512     | 2,048     |
| Dense        | LeakyReLU  | 1024    | 525,312   |
| BatchNorm    | -          | 1024    | 4,096     |
| Dense        | Sigmoid    | 784     | 803,600   |
| Reshape      | -          | 28x28x1 | 0         |
| Total Params |            |         | 1,493,520 |
| Discriminator | Activation | Output  | Params  |
|---------------|------------|---------|---------|
| Input         | -          | 28x28x1 | 0       |
| Flatten       | -          | 784     | 0       |
| Dense         | LeakyReLU  | 512     | 401,920 |
| Dense         | LeakyReLU  | 256     | 131,328 |
| Dense         | Sigmoid    | 1       | 257     |
| Total Params  |            |         | 533,505 |
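The tables above translate directly into Keras. Below is a sketch of both networks plus the standard stacked-GAN wiring; the layer sizes follow the tables, while the Adam learning rate and the LeakyReLU slope of 0.2 are my assumptions, not values prescribed above:

```python
from tensorflow.keras import layers, models, optimizers

def build_generator(noise_dim=100):
    # Fully connected Generator from the table above (1,493,520 params).
    return models.Sequential([
        layers.Input(shape=(noise_dim,)),
        layers.Dense(256), layers.LeakyReLU(0.2), layers.BatchNormalization(),
        layers.Dense(512), layers.LeakyReLU(0.2), layers.BatchNormalization(),
        layers.Dense(1024), layers.LeakyReLU(0.2), layers.BatchNormalization(),
        layers.Dense(784, activation="sigmoid"),  # pixel values in [0, 1]
        layers.Reshape((28, 28, 1)),
    ])

def build_discriminator():
    # Fully connected Discriminator from the table above (533,505 params).
    return models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(512), layers.LeakyReLU(0.2),
        layers.Dense(256), layers.LeakyReLU(0.2),
        layers.Dense(1, activation="sigmoid"),  # P(input is real)
    ])

discriminator = build_discriminator()
discriminator.compile(optimizer=optimizers.Adam(1e-4), loss="binary_crossentropy")

# Freeze D inside the stacked model so that training the stack updates
# only G; D is still trained directly on its own real/fake batches.
generator = build_generator()
discriminator.trainable = False
gan = models.Sequential([generator, discriminator])
gan.compile(optimizer=optimizers.Adam(1e-4), loss="binary_crossentropy")
```

With this wiring, one training iteration trains `discriminator` on a mixed real/fake batch and then trains `gan` on noise with labels of 1, which implements the minimax objective above.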

Fashion-MNIST GAN

This is another dataset of 28x28 grayscale images of clothes and accessories, intended as a drop-in replacement for the MNIST dataset. We use the same architecture as above for training, and the learned model is able to generate grainy but somewhat realistic-looking images.

[Figure: images from the Fashion-MNIST dataset (original) alongside images produced by the Generator (generated)]
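Because the two datasets share shapes and loader APIs, swapping them in code is a one-line change. A sketch using the Keras dataset helpers; the reshape and the scaling to [0, 1] (to match the Generator's sigmoid output) are my assumptions:

```python
from tensorflow.keras.datasets import fashion_mnist  # was: mnist

# Same 60,000 x 28 x 28 uint8 layout as MNIST, so nothing else changes.
(x_train, _), _ = fashion_mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
```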

As with MNIST, the generated images for Fashion-MNIST look noisy. I think a Generator with convolutional layers would produce better-looking images, and I intend to explore that in a future notebook or post.

Use cases

GANs have grown in popularity ever since they were introduced in 2014. The following is a list of applications that I find interesting:

  • Synthetic data generation
  • Generate images from text
  • Increase image resolution
  • Image inpainting: e.g., fill in the holes in old scanned images, or generate a new face from a partial image of a face
  • Transfer images from one domain to another: e.g., draw a cartoon and generate a realistic-looking image, or take a day-scene photograph and turn it into a night scene

There is a GitHub repo that lists all GAN papers along with their arXiv links.

The plot below shows the popularity of GANs (a GANplosion of papers).

[Plot: GANplosion of papers]

Last updated: 01-Dec-2018

Challenges

Mode Collapse

The choice of initialization, optimizer, loss function, and network architecture makes a huge difference to convergence. The most common failure mode in GANs is mode collapse. For example, a Generator trained on the MNIST dataset may learn to output a single digit every time, and the Discriminator will learn to label images of that digit, both generated and real, as fake. The Generator will then move on to another digit, and the Discriminator will learn to mark images of that digit as fake. Here, the Generator is generating images from a single cluster, whereas we want it to sample from multiple clusters simultaneously.
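A cheap, if crude, way to watch for mode collapse during training is to measure the diversity of a generated batch: if per-pixel variation across samples collapses toward zero, G is likely emitting near-identical outputs. A sketch, again assuming the hypothetical `generator` model from the architecture example above:

```python
import numpy as np

noise = np.random.normal(0.0, 1.0, size=(256, 100))
fake = generator.predict(noise, verbose=0).reshape(256, -1)

# Per-pixel standard deviation across the batch, averaged over pixels:
# values near zero mean every sample looks the same (mode collapse).
diversity = fake.std(axis=0).mean()
print(f"batch diversity: {diversity:.4f}")
```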

The following excerpt from Improved Techniques for Training GANs explains the challenge of convergence really well. The paper is a good read to learn more about techniques that can make GANs train better and overcome common issues.

Training GANs requires finding a Nash equilibrium of a non-convex game with continuous, high-dimensional parameters. GANs are typically trained using gradient descent techniques that are designed to find a low value of a cost function, rather than to find the Nash equilibrium of a game. When used to seek for a Nash equilibrium, these algorithms may fail to converge.

My Notebooks:

  1. MNIST Colab
  2. Fashion-MNIST is a drop-in replacement, so you just need to download the dataset and replace the MNIST-loading part.

Datasets:

  1. MNIST: http://yann.lecun.com/exdb/mnist/
  2. Fashion MNIST: https://github.com/zalandoresearch/fashion-mnist

Footnotes:

  1. Model accuracy improvements: Data is king (http://research.baidu.com/deep-learning-scaling-predictable-empirically/)
  2. Answer on Quora

References:

  1. Generative Adversarial Networks link
  2. Udacity AI Nanodegree link
  3. Ian Goodfellow’s NIPS 2016 Tutorial link
  4. Article on GAN challenges link
  5. Improved Techniques for Training GANs link