Visual Inspection of GAN Training

Generative Adversarial Networks1, GANs for short, are one of the most popular classes of generative models. Generative models learn the probability distribution of the data they are trying to model, so they can sample new examples from it. For example, you can use GANs to generate anime faces2, icons3, fonts4, and even human faces5. If I may suggest an alternate name: Synthetic Data Generators.

In this post, we are going to inspect the images generated by a GAN, but with a twist. Instead of merely displaying the generated images, we will run them through the same clustering algorithm used for the real images and overlay them on the clusters of real images. The idea for this visualisation came from Sharon Zhou's tweet.

GAN Training and Clustering of Generated Images

Dataset

We use the MNIST6 dataset of handwritten digits. Our GAN will learn to generate these digits. You can read more about the dataset via reference 6.

A snapshot of MNIST digits

Cluster MNIST Digits

We run the 60k handwritten digits from the MNIST dataset through the clustering algorithm and visualise the output in 2D. Clear, well-defined clusters emerge for most digits:

  • The clusters of digits 4, 7, and 9 are close to each other
  • Similarly, the clusters of 3, 5, and 8 are quite close
  • The clusters of 0, 1, 2, and 6 are clear and well separated
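The observations above come from eyeballing the plot. One way to put numbers on them is to compare cluster centroids in the 2D output. The sketch below is minimal and rests on two assumptions: you already have a 2D coordinate and a digit label for each image, and the names `centroids` and `closest_clusters` are hypothetical helpers. The toy coordinates are made up for illustration and do not come from the post's plot.

```python
import math
from collections import defaultdict


def centroids(points, labels):
    """Average the 2D coordinates of all points that share a digit label."""
    sums = defaultdict(lambda: [0.0, 0.0, 0])
    for (x, y), label in zip(points, labels):
        acc = sums[label]
        acc[0] += x
        acc[1] += y
        acc[2] += 1
    return {label: (sx / n, sy / n) for label, (sx, sy, n) in sums.items()}


def closest_clusters(cents, label):
    """Return the other digit labels, sorted by centroid distance to `label`."""
    return sorted(
        (other for other in cents if other != label),
        key=lambda other: math.dist(cents[label], cents[other]),
    )


# Made-up 2D coordinates (not the post's data): 4, 7 and 9 sit close together, 0 is far away.
points = [(5.0, 5.0), (5.2, 4.8), (6.0, 5.5), (6.1, 5.4),
          (5.5, 6.2), (5.6, 6.0), (-8.0, -8.0), (-8.2, -7.9)]
labels = [4, 4, 7, 7, 9, 9, 0, 0]
cents = centroids(points, labels)
print(closest_clusters(cents, 4))  # [7, 9, 0]
```

Centroid distances are only a coarse summary. Two clusters can have nearby centroids and still overlap very little.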

Next, we are going to train our network to generate digits.

Model Architecture

A Generative Adversarial Network consists of two neural networks: the Generator and the Discriminator.

The Generator's training objective is to produce real-looking handwritten digits that are (almost) indistinguishable from those in the real dataset.

The Discriminator's goal is to clearly distinguish fake images from real ones:

  • Fake - the images generated by the Generator
  • Real - the real handwritten digits from the dataset
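For reference, the original GAN paper (reference 1) frames these two goals as a two-player minimax game over the value function below. In that formulation, D(x) is the Discriminator's estimated probability that an image x is real, and G(z) is the image the Generator produces from a noise vector z. This post's exact loss formulation is not stated, so treat the formula as the textbook version rather than the precise setup used here.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The Discriminator pushes D(x) towards 1 for real digits and D(G(z)) towards 0 for fakes. The Generator pushes D(G(z)) towards 1.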

The model and training approach are very similar to those in my earlier post7 on GANs. That post also includes a Colab Notebook8 for you to play around with.

Training Begins

In the beginning, the Generator is quite bad at faking the images, whereas the Discriminator is quite good at separating real handwritten digits from the fake ones.

These images look like random noise, which is essentially what they are: the Generator takes random noise sampled from a Gaussian distribution as its input and has not yet learned to turn it into digits. At this point, the Discriminator is very confident in its decisions.
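To make the Generator's input concrete, here is a minimal sketch of drawing a batch of latent vectors from a standard Gaussian. Two things are assumptions: the 100-dimensional latent size is a common choice, not necessarily the one used in this post, and `sample_latent` is a hypothetical helper. A real implementation would typically use its deep-learning framework's random tensor function instead of plain lists.

```python
import random


def sample_latent(batch_size, latent_dim=100, seed=None):
    """Draw latent vectors z ~ N(0, I): one list of `latent_dim` floats per image.

    Assumption: latent_dim=100 is a common default, not a value taken from this post.
    """
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(batch_size)]


z = sample_latent(batch_size=16, seed=0)
print(len(z), len(z[0]))  # 16 100
```

Before training, the Generator maps each of these vectors to an image that is little more than structured noise. That is why the first snapshots look the way they do.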

The clusters of real digits are the same as in the previous plot. In addition, we now run the generated images through the same clustering algorithm and plot them alongside the real ones in 2D.

  • Real - the real handwritten digits from the dataset
  • Fake - the images generated by the Generator
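For this overlay to be meaningful, real and generated images have to live in the same 2D space. The sketch below keeps dependencies at zero by implementing PCA in plain Python with power iteration. PCA is a stand-in here, not necessarily the algorithm this post uses. The sketch fits the projection once on the real digits and reuses the fitted mean and components for the generated ones. `SharedProjection` is a hypothetical name. If the projection is t-SNE instead, note that scikit-learn's `TSNE` offers `fit_transform` but no separate `transform`, so real and generated images would typically be embedded together in a single run.

```python
import math
import random


def _top_components(rows, k=2, iters=200, seed=0):
    """Top-k principal directions of already-centred rows (power iteration + deflation)."""
    dim = len(rows[0])
    rng = random.Random(seed)
    comps = []
    for _ in range(k):
        v = [rng.random() - 0.5 for _ in range(dim)]
        for _ in range(iters):
            # w = X^T X v, computed row by row without forming the covariance matrix.
            w = [0.0] * dim
            for r in rows:
                s = sum(a * b for a, b in zip(r, v))
                for j in range(dim):
                    w[j] += s * r[j]
            # Remove directions already found so we converge to the next component.
            for c in comps:
                d = sum(a * b for a, b in zip(w, c))
                w = [a - d * b for a, b in zip(w, c)]
            norm = math.sqrt(sum(a * a for a in w)) or 1.0
            v = [a / norm for a in w]
        comps.append(v)
    return comps


class SharedProjection:
    """Stand-in 2D PCA: fit on real images only, then project real and fake with it."""

    def fit(self, real_rows):
        n = len(real_rows)
        self.mean = [sum(col) / n for col in zip(*real_rows)]
        centred = [[a - m for a, m in zip(r, self.mean)] for r in real_rows]
        self.components = _top_components(centred, k=2)
        return self

    def transform(self, rows):
        out = []
        for r in rows:
            c = [a - m for a, m in zip(r, self.mean)]
            out.append(tuple(sum(a * b for a, b in zip(c, comp)) for comp in self.components))
        return out


rng = random.Random(1)
# Toy "images": 8-value vectors standing in for flattened 28x28 MNIST digits (made up).
real = [[rng.random() for _ in range(8)] for _ in range(50)]
fake = [[rng.random() for _ in range(8)] for _ in range(10)]
proj = SharedProjection().fit(real)
real_2d = proj.transform(real)
fake_2d = proj.transform(fake)  # plotted on the same axes as real_2d
```

Fitting only on real digits keeps the real clusters fixed from frame to frame. Only the generated points move as training progresses.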

Shaky Ground

After several failed attempts, something magical9 happens: the Generator begins to produce crude representations of the digits.

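At this point, there is a real danger that the Generator gets stuck in a local minimum and collapses onto one or a few modes. This failure is known as mode collapse: the Generator keeps producing only a few kinds of digits.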
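One rough way to watch for mode collapse is to check how many digit classes the generated images cover and how evenly they are spread. This sketch assumes each generated image has already been given a digit label, for example by a separately trained MNIST classifier. That classifier is not part of this post. `mode_coverage` is a hypothetical helper.

```python
import math
from collections import Counter


def mode_coverage(labels, num_classes=10):
    """Return (number of classes seen, normalised entropy) for predicted digit labels.

    Assumption: `labels` come from some external digit classifier. Entropy is 1.0
    when all classes are equally frequent and 0.0 when every sample is the same
    class, the extreme case of mode collapse.
    """
    counts = Counter(labels)
    n = len(labels)
    entropy = -sum((c / n) * math.log(c / n) for c in counts.values())
    return len(counts), entropy / math.log(num_classes)


healthy = list(range(10)) * 10   # every digit equally often
collapsed = [1] * 95 + [7] * 5   # almost only 1s
print(mode_coverage(healthy))
print(mode_coverage(collapsed))
```

A sudden drop in the number of classes seen, or in the entropy, across training snapshots is a hint that the Generator is collapsing.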

The Generator Rises

The Generator survives mode collapse

Over time, the Generator gets better, and the Discriminator becomes less sure whether a given image is real or fake.

In the end, our Generator emerges victorious: it has learned to fool the Discriminator, which can no longer tell whether a digit is real or fake.
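The Discriminator's growing uncertainty shows up directly in its loss. Below is a minimal sketch of the usual binary cross-entropy losses in plain Python. For the Generator it uses the non-saturating loss, -log D(G(z)), which the original GAN paper suggests in place of minimising log(1 - D(G(z))); whether this post uses that variant is an assumption. `d_real` and `d_fake` stand for hypothetical lists of Discriminator outputs in (0, 1). When the Discriminator outputs 0.5 for everything, which is the theoretical optimum once generated and real distributions match, its loss equals 2 ln 2, about 1.386.

```python
import math


def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy: real images should score 1, generated images 0."""
    real_term = -sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = -sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def generator_loss(d_fake):
    """Non-saturating generator loss: push D(G(z)) towards 1."""
    return -sum(math.log(p) for p in d_fake) / len(d_fake)


# Hypothetical outputs. Early training: D is confident (real ~0.99, fake ~0.01).
print(discriminator_loss([0.99] * 4, [0.01] * 4), generator_loss([0.01] * 4))
# Late training: D cannot tell real from fake and outputs 0.5 for both.
print(discriminator_loss([0.5] * 4, [0.5] * 4))
```

Early on, the Discriminator loss is near zero while the Generator loss is large. As the Generator improves, the Discriminator loss drifts up towards 2 ln 2.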

The generated digits start to fall into the same clusters as the real digits. For example, the real ones (1s) from the MNIST dataset and the generated ones now belong to the same cluster.
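Whether a generated digit "falls into" a real cluster can be made concrete by assigning each generated 2D point to the nearest real-digit centroid. This is a minimal sketch. The centroids and points below are hypothetical values, not read off the post's plots, and `assign_to_clusters` is a hypothetical helper.

```python
import math


def assign_to_clusters(fake_points, real_centroids):
    """Label each generated 2D point with the digit whose real-cluster centroid is nearest."""
    return [
        min(real_centroids, key=lambda digit: math.dist(p, real_centroids[digit]))
        for p in fake_points
    ]


# Hypothetical centroids of three real-digit clusters in the 2D plot.
real_centroids = {0: (-8.0, -8.0), 1: (9.0, 0.5), 7: (5.5, 5.5)}
fake_points = [(8.7, 0.2), (-7.5, -8.3), (5.0, 6.0)]
print(assign_to_clusters(fake_points, real_centroids))  # [1, 0, 7]
```

Proximity in a 2D projection is only a proxy. A generated point can land near the "1" cluster without being a convincing 1, so this check complements looking at the images rather than replacing it.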

From Noise to Digits

The Generator created these digits.

Considering it had to start from random noise, it has done a good job. However, there is still room for improvement: some images are not very clear or are incomplete, and the network has trouble generating zeros, threes, and eights.

You can get better results with DCGAN or Wasserstein GAN.
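To give a flavour of what changes with Wasserstein GAN, here is a minimal sketch of its losses as proposed in the WGAN paper (Arjovsky et al., 2017). The Discriminator becomes a "critic" that outputs unbounded scores instead of probabilities. Its weights are clipped to a small range, 0.01 in the paper, to keep it roughly Lipschitz. The score lists and the flat weight list are hypothetical stand-ins for real tensors and model parameters.

```python
def critic_loss(scores_real, scores_fake):
    """WGAN critic loss to minimise: mean score on fakes minus mean score on reals."""
    return sum(scores_fake) / len(scores_fake) - sum(scores_real) / len(scores_real)


def generator_loss(scores_fake):
    """WGAN generator loss to minimise: raise the critic's mean score on generated images."""
    return -sum(scores_fake) / len(scores_fake)


def clip_weights(weights, c=0.01):
    """Weight clipping from the original WGAN: clamp every parameter to [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]


# Hypothetical critic scores and parameters, for illustration only.
print(critic_loss([2.0, 4.0], [1.0, 1.0]))   # -2.0
print(generator_loss([1.0, 3.0]))            # -2.0
print(clip_weights([0.5, -0.5, 0.005]))      # [0.01, -0.01, 0.005]
```

Later work replaced weight clipping with a gradient penalty (WGAN-GP), but the loss structure above stays the same.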

The Next Steps

References:
  1. Generative Adversarial Networks [LINK]
  2. Generate Anime Faces [LINK]
  3. Generate Icons using DCGAN [LINK]
  4. Generate Fonts [LINK]
  5. Generate Human Faces using StyleGAN2 [LINK]
  6. MNIST Dataset [LINK]
  7. My 2017 introductory post on GANs [BLOG]
  8. GAN Notebook
  9. PS: "Magical" should not be taken in the literal sense. The network learns through backpropagation. [BLOG]
  10. DCGAN experimentation on various datasets (already ancient by the GAN timeline) [BLOG]