The Smallest Generative Adversarial Network

Author

Erik-Jan van Kesteren

Introduction

Generative adversarial networks (GANs) have been very successful at many synthetic data generation tasks, e.g. this person does not exist. But those models are complicated and hard to understand! In this post, I ask the question: how do GANs work on the smallest possible scale? The goal is to create an intuitive conceptual understanding of GANs for people who are familiar with data, simulation, and a bit of neural networks (e.g., statisticians).

The full code as an R script can be found here.

# setup with seeds for reproducibility
library(torch)
set.seed(45)
torch_manual_seed(45)

The toy problem

Assume the observed data is a sample from a true distribution. In this problem, imagine I don’t know the true distribution, so I want to estimate it. I can do this in several ways, for example with kernel density estimation. Here, I take another approach: generative adversarial training. First, let’s create and plot the example data, generated from \(\mathcal{N}(1, 3)\), a normal distribution with mean 1 and standard deviation 3:

# True distribution is normal(mean = 1, sd = 3), we sample 1000 points
N <- 1000
y <- rnorm(N, 1, 3)

# we need to create torch tensor of this data to use in torch
y_torch <- torch_tensor(matrix(y), requires_grad = FALSE)

# let's look at the data
plot(density(y), bty = "L", main = "Density of real data y")
curve(dnorm(x, 1, 3), add = TRUE, col = "darkgrey")
rug(y)

The generator

In a generative adversarial network, a generator and a discriminator are fitted together, each with a different task: the discriminator tries to distinguish fake and real data as well as possible, and the generator tries to create fake data that the discriminator cannot distinguish from real.

Specifically, the generator is a function that maps a value sampled from a latent space onto the data space. The latent space can have many dimensions, but because here I want to create the smallest possible GAN I use only one latent dimension. In this model, I will sample this latent data from \(\mathcal{N}(0, 1)\).

# an extremely simple generator with 2 parameters:
# a weight and a bias
generator <- nn_linear(1, 1)

# Let's try it!
as.numeric(generator(rnorm(1)))
[1] 0.9891596

Theoretical note

Because I know the real distribution is \(Y \sim \mathcal{N}(1, 3)\) and the latent distribution is \(Z \sim \mathcal{N}(0, 1)\), I know that \(Y = 3Z + 1\). I can later check if generative adversarial training correctly recovers this function: a weight of 3 and a bias of 1.
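As a quick sanity check of this claim, here is a minimal base R sketch that transforms standard normal draws with a weight of 3 and a bias of 1. The names z_check and y_check are only for illustration and are not part of the original script.

```r
# Illustrative check: transform latent N(0, 1) draws with weight 3 and bias 1
set.seed(45)
z_check <- rnorm(1e5)       # latent samples Z ~ N(0, 1)
y_check <- 3 * z_check + 1  # candidate generator: Y = 3Z + 1

# The sample mean should be near 1 and the sample sd near 3
c(mean = mean(y_check), sd = sd(y_check))
```

Note that a weight of -3 would produce exactly the same distribution, which is why the absolute value of the weight is taken later on when reading off the estimated standard deviation.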

The discriminator

The discriminator is supposed to distinguish real data points from fake data points (created by the generator). It should have enough capacity to learn all relevant properties of the true data generating mechanism. In this case, I have only a mean and a variance, so I can use a relatively small feed-forward classifier with only two hidden nodes:

# an extremely simple discriminator with 2 hidden nodes:
discriminator <- nn_sequential(
  nn_linear(1, 2),
  nn_sigmoid(),
  nn_linear(2, 1),
  nn_sigmoid()
)

# Let's try it! Should give a number between 0 and 1
as.numeric(discriminator(rnorm(1, 1, 3)))
[1] 0.5259537

Let’s label true data as 1 and fake data as 0 and use binary cross-entropy as the loss function for the discriminator. In other words, this discriminator is basically a (non-linear) logistic regression.

is_real <- torch_ones_like(y_torch)
is_fake <- torch_zeros_like(y_torch)

criterion <- nn_bce_loss()
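To make the loss concrete, the following base R sketch writes out binary cross-entropy by hand. The helper bce() is hypothetical: it is meant to mirror what nn_bce_loss() computes with its default mean reduction, and is not part of the original script.

```r
# Hypothetical helper: mean binary cross-entropy for
# predicted probabilities p and 0/1 labels t
bce <- function(p, t) {
  -mean(t * log(p) + (1 - t) * log(1 - p))
}

# A discriminator that always answers 0.5 is maximally unsure:
# its loss on real data (label 1) equals log(2)
p_unsure <- rep(0.5, 4)
bce(p_unsure, rep(1, 4))

# Confident and correct predictions give a loss close to 0
bce(c(0.99, 0.98), c(1, 1))
```

This is the same quantity that is minimized when fitting a logistic regression, which is why the discriminator can be read as a small non-linear logistic regression.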

The training loop

After initializing the neural networks, they need to be trained. First, set up everything necessary to perform optimization:

# Two time-scale update rule: discriminator learning rate
# is twice as high as the generator learning rate
# https://arxiv.org/abs/1706.08500
optg <- optim_adam(generator$parameters, lr = 1e-2)
optd <- optim_adam(discriminator$parameters, lr = 2e-2)

n_epoch <- 500
dlosses <- numeric(n_epoch)
glosses <- numeric(n_epoch)

Then, start the training loop, which also includes a nice plot:

# Plot the target density: a normal with the sample mean and sd of y
curve(
  expr = dnorm(x, mean(y), sd(y)), 
  from = -10, to = 10, 
  ylim = c(0, dnorm(0)*2/3),
  main = "Density approximation by GAN",
  ylab = "Density", 
  xlab = ""
)

# Start training
for (i in 1:n_epoch) {
  
  # generate fake data
  inp <- torch_randn(N, 1, requires_grad = FALSE)
  y_fake <- generator(inp)
  
  # train the discriminator
  discriminator$zero_grad()
  
  # the discriminator loss is its ability to classify
  # real and fake data correctly
  prob_real  <- discriminator(y_torch)
  # detach: the discriminator update does not need generator gradients
  prob_fake  <- discriminator(y_fake$detach())
  dloss_real <- criterion(prob_real, is_real)
  dloss_fake <- criterion(prob_fake, is_fake)
  dloss <- dloss_real + dloss_fake
  
  dloss$backward()
  optd$step()
  dlosses[i] <- dloss$item()
  
  # train the generator
  generator$zero_grad()
  
  # the generator loss is its ability to create
  # data that is classified by the discriminator
  # as real data
  prob_fake <- discriminator(generator(inp))
  gloss <- criterion(prob_fake, is_real)
  
  gloss$backward()
  optg$step()
  glosses[i] <- gloss$item()
  
  # Print current state
  if (interactive()) 
    cat("\r iteration", i, "dloss:", dlosses[i], "gloss:", glosses[i])
  
  # Plot current density estimate, assuming normal distribution
  if (i %% (n_epoch/10) == 1) {
    mu_hat <- as.numeric(generator$parameters[["bias"]])
    sd_hat <- abs(as.numeric(generator$parameters[["weight"]]))
    c_val <- (n_epoch - i) / n_epoch * 0.5 + 0.5
    curve(dnorm(x, mu_hat, sd_hat), add = TRUE, col = rgb(c_val, c_val, c_val))
    curve(dnorm(x, mean(y), sd(y)), add = TRUE)
  }
}

As the generator trains (lighter to darker grey lines) it approximates the true density (black line) better and better.

Then, it’s also possible to plot the losses for the discriminator and the generator, and see that they are indeed competing in a zero-sum game; as the generator loss goes down, the discriminator loss goes up (and vice versa):

par(mfrow = c(1, 2))
plot(glosses, type = "l", ylab = "loss", xlab = "Epoch", 
     main = "Generator loss", col = "darkblue", bty = "L")
plot(dlosses, type = "l", ylab = "loss", xlab = "Epoch", 
     main = "Discriminator loss", col = "darkred", bty = "L")
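As a rough reference for reading these plots: if the discriminator cannot tell real from fake at all, it outputs 0.5 for every point. Under that assumption, and with the losses defined as in the training loop above, the values the losses should hover around can be computed by hand. The names below are illustrative only.

```r
# Assumption: an undecided discriminator outputs 0.5 for every data point
p_eq <- 0.5

# Discriminator loss is the sum of the real and fake cross-entropy terms
dloss_eq <- -log(p_eq) + -log(1 - p_eq)

# Generator loss is the cross-entropy of fake data labelled as real
gloss_eq <- -log(p_eq)

c(dloss = dloss_eq, gloss = gloss_eq)
```

These work out to 2 log(2) (about 1.39) for the discriminator and log(2) (about 0.69) for the generator. Horizontal lines at these values, for example via abline(h = log(2)) after the generator plot, could serve as a visual guide.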

Lastly, inspect the parameters of the generator (the way we set it up, these should be the mean and sd of the target distribution). Then, also generate some fake data.

# inspect parameters
(mu_hat <- as.numeric(generator$parameters[["bias"]]))
[1] 0.9485998
(sd_hat <- abs(as.numeric(generator$parameters[["weight"]])))
[1] 3.118304

# generate fake data
y_hat <- as.numeric(generator(matrix(rnorm(1000))))

plot(density(y_hat), bty = "L", main = "Density of fake data y_hat")
curve(dnorm(x, mu_hat, sd_hat), add = TRUE, col = "darkgrey")
rug(y_hat)
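One possible way to put a number on how close the learned distribution is to the target is the Kullback-Leibler divergence between two normal distributions, which has a closed form. The sketch below is an addition for illustration, not part of the original script: the helper kl_normal() is hypothetical, and the estimates are copied from the output above.

```r
# Hypothetical helper: KL divergence from N(m1, s1) to N(m0, s0),
# with s1 and s0 given as standard deviations
kl_normal <- function(m1, s1, m0, s0) {
  log(s0 / s1) + (s1^2 + (m1 - m0)^2) / (2 * s0^2) - 0.5
}

# Estimates reported above versus the true N(1, 3)
kl_est <- kl_normal(0.9485998, 3.118304, 1, 3)
kl_est

# Identical distributions have zero divergence
kl_normal(1, 3, 1, 3)
```

A value close to zero indicates that the generator's implied normal distribution is nearly indistinguishable from the target.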

Conclusion

In this post, I taught a generator to generate \(\mathcal{N}(1, 3)\) data by making it compete against a discriminator. The parameters of the generator were very close to the true values: 0.95 \(\approx\) 1 and 3.12 \(\approx\) 3. In other words, without having specified any likelihood, I have approximated maximum likelihood estimation for this distribution! For other distributions and with fewer assumptions, both the generator and the discriminator would need to be extended. The generator would need to do more complicated transformations to, say, generate uniformly distributed data. The discriminator would also need to be more complicated to create nonlinear or multivariate decision boundaries.

In creating this smallest GAN, I’ve learned the following things:

  • The process requires quite a bit of fine-tuning to the problem at hand. I started with much larger networks for the same problem, and they performed very poorly.
  • The optimization needs tuning as well. In the end, I followed the two time-scale update rule, which made a huge difference for stability.
  • The discriminator absolutely needs to have enough capacity to detect all the relevant features of the real distribution that you want to approximate. This reminds me of approximate Bayesian computation, where you need to select the correct statistics to optimize for. If you don’t do this, the result will be mode collapse: the generator will just generate the mean of the distribution and the discriminator will think this fits perfectly with the true distribution.
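The point about capacity can be illustrated without any neural network. In this base R sketch (an illustration under simplifying assumptions, not the author's experiment), a collapsed generator with weight 0 outputs only the mean of the real data. A check that compares only means sees no difference, while a check that also compares spread does. All names are hypothetical.

```r
set.seed(45)
y_real <- rnorm(1000, 1, 3)

# Collapsed generator: weight 0, bias equal to the real mean
z <- rnorm(1000)
y_collapsed <- 0 * z + mean(y_real)

# Comparing only the means suggests a perfect fit ...
mean_gap <- abs(mean(y_real) - mean(y_collapsed))

# ... while comparing standard deviations exposes the collapse
sd_gap <- abs(sd(y_real) - sd(y_collapsed))

c(mean_gap = mean_gap, sd_gap = sd_gap)
```

A discriminator that can only pick up on the mean is in the position of the first check, so it has no signal with which to push the generator away from the collapsed solution.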