# setup with seeds for reproducibility
library(torch)
set.seed(45)
torch_manual_seed(45)
The Smallest Generative Adversarial Network
Introduction
Generative adversarial networks (GANs) have been great at many synthetic data generation tasks, e.g. this person does not exist. But those models are complicated and hard to understand! In this post, I ask the question: how do GANs work on the smallest possible scale? The goal is to create an intuitive, conceptual understanding of GANs for people who are familiar with data, simulation, and a bit of neural networks (e.g., statisticians).
The full code as an R script can be found here.
The toy problem
Assume the observed data is a sample from a true distribution. In this problem, imagine I don’t know the true distribution, so I want to estimate it. I can do this in several ways, for example with kernel density estimation. Here, I take another approach: generative adversarial training. First, let’s create and plot the example data, generated from \(\mathcal{N}(1, 3)\):
# True distribution is normal(1, 3); we sample 1000 points
N <- 1000
y <- rnorm(N, 1, 3)

# we need to create a torch tensor of this data to use it in torch
y_torch <- torch_tensor(matrix(y), requires_grad = FALSE)
# let's look at the data
plot(density(y), bty = "L", main = "Density of real data y")
curve(dnorm(x, 1, 3), add = TRUE, col = "darkgrey")
rug(y)
The generator
In a generative adversarial network, a generator and a discriminator are fitted together, each with a different task: the discriminator tries to distinguish fake and real data as well as possible, and the generator tries to create fake data that the discriminator cannot distinguish from real.
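Formally, this is the minimax game of the original GAN paper (Goodfellow et al., 2014), with discriminator \(D\), generator \(G\), data distribution \(p_{\text{data}}\), and latent distribution \(p_z\):

\[
\min_G \max_D \; \mathbb{E}_{y \sim p_{\text{data}}}\left[\log D(y)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
\]

In the code below, the generator is trained with the common non-saturating variant of this objective: instead of minimizing \(\log(1 - D(G(z)))\), it maximizes \(\log D(G(z))\), which is exactly what labelling the fake data as real in the generator’s loss achieves.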
Specifically, the generator is a function that maps a value sampled from a latent space onto the data space. The latent space can have many dimensions, but because I want to create the smallest possible GAN, I use only one latent dimension here. In this model, I will sample this latent data from \(\mathcal{N}(0, 1)\).
# an extremely simple generator with 2 parameters:
# a weight and a bias
generator <- nn_linear(1, 1)
# Let's try it!
as.numeric(generator(rnorm(1)))
[1] 0.9891596
Because I know the real distribution is \(Y \sim \mathcal{N}(1, 3)\) and the latent distribution is \(Z \sim \mathcal{N}(0, 1)\), I know that \(Y = 3Z + 1\). I can later check if generative adversarial training correctly recovers this function: a weight of 3 and a bias of 1.
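As a quick standalone check (run separately from the main script, so it does not disturb the seeded RNG stream used above), the affine transformation can be simulated directly:

# sanity check: 3 * Z + 1 with Z ~ N(0, 1) should have mean ~1 and sd ~3
z <- rnorm(1e5)
round(c(mean = mean(3 * z + 1), sd = sd(3 * z + 1)), 2)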
The discriminator
The discriminator is supposed to distinguish real data points from fake data points (created by the generator). It should have enough capacity to learn all relevant properties of the true data-generating mechanism. In this case, the true distribution has only a mean and a variance, so a relatively small feed-forward classifier with only two hidden nodes will do:
# an extremely simple discriminator with 2 hidden nodes:
discriminator <- nn_sequential(
  nn_linear(1, 2),
  nn_sigmoid(),
  nn_linear(2, 1),
  nn_sigmoid()
)
# Let's try it! Should give a number between 0 and 1
as.numeric(discriminator(rnorm(1, 1, 3)))
[1] 0.5259537
Let’s label true data as 1 and fake data as 0 and use binary cross-entropy as the loss function for the discriminator. In other words, this discriminator is basically a (non-linear) logistic regression.
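For reference, with predicted probabilities \(\hat{p}_i\) and labels \(t_i \in \{0, 1\}\), the binary cross-entropy loss (with the default mean reduction of nn_bce_loss()) is

\[
\mathrm{BCE} = -\frac{1}{n} \sum_{i=1}^{n} \Big[ t_i \log \hat{p}_i + (1 - t_i) \log(1 - \hat{p}_i) \Big],
\]

which is exactly the negative log-likelihood of a logistic regression model.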
is_real <- torch_ones_like(y_torch)
is_fake <- torch_zeros_like(y_torch)

criterion <- nn_bce_loss()
The training loop
After initializing the neural networks, they need to be trained. First, set up everything necessary to perform optimization:
# Two time-scale update rule: discriminator learning rate
# is twice as high as the generator learning rate
# https://arxiv.org/abs/1706.08500
optg <- optim_adam(generator$parameters, lr = 1e-2)
optd <- optim_adam(discriminator$parameters, lr = 2e-2)

n_epoch <- 500
dlosses <- numeric(n_epoch)
glosses <- numeric(n_epoch)
Then, start the training loop, which also includes a nice plot:
# Plot theoretical density to learn
curve(
  expr = dnorm(x, mean(y), sd(y)),
  from = -10, to = 10,
  ylim = c(0, dnorm(0) * 2 / 3),
  main = "Density approximation by GAN",
  ylab = "Density",
  xlab = ""
)
# Start training
for (i in 1:n_epoch) {
  # generate fake data
  inp    <- torch_randn(N, 1, requires_grad = FALSE)
  y_fake <- generator(inp)

  # train the discriminator
  discriminator$zero_grad()

  # the discriminator loss is its ability to classify
  # real and fake data correctly
  prob_real  <- discriminator(y_torch)
  prob_fake  <- discriminator(y_fake)
  dloss_real <- criterion(prob_real, is_real)
  dloss_fake <- criterion(prob_fake, is_fake)
  dloss      <- dloss_real + dloss_fake

  dloss$backward()
  optd$step()
  dlosses[i] <- dloss$item()

  # train the generator
  generator$zero_grad()

  # the generator loss is its ability to create
  # data that is classified by the discriminator
  # as real data
  prob_fake <- discriminator(generator(inp))
  gloss     <- criterion(prob_fake, is_real)

  gloss$backward()
  optg$step()
  glosses[i] <- gloss$item()

  # Print current state
  if (interactive())
    cat("\r iteration", i, "dloss:", dlosses[i], "gloss:", glosses[i])

  # Plot current density estimate, assuming normal distribution
  if (i %% (n_epoch / 10) == 1) {
    mu_hat <- as.numeric(generator$parameters[["bias"]])
    sd_hat <- abs(as.numeric(generator$parameters[["weight"]]))
    c_val  <- (n_epoch - i) / n_epoch * 0.5 + 0.5
    curve(dnorm(x, mu_hat, sd_hat), add = TRUE, col = rgb(c_val, c_val, c_val))
    curve(dnorm(x, mean(y), sd(y)), add = TRUE)
  }
}
Then, it’s also possible to plot the losses for the discriminator and the generator, and see that they are indeed competing in a zero-sum game; as the generator loss goes down, the discriminator loss goes up (and vice versa):
par(mfrow = c(1, 2))
plot(glosses, type = "l", ylab = "loss", xlab = "Epoch",
     main = "Generator loss", col = "darkblue", bty = "L")
plot(dlosses, type = "l", ylab = "loss", xlab = "Epoch",
     main = "Discriminator loss", col = "darkred", bty = "L")
Lastly, inspect the parameters of the generator (the way it is set up, its bias should be the mean and its absolute weight the sd of the target distribution). Then, also generate some fake data.
# inspect parameters
(mu_hat <- as.numeric(generator$parameters[["bias"]]))
[1] 0.9485998
(sd_hat <- abs(as.numeric(generator$parameters[["weight"]])))
[1] 3.118304
# generate fake data
y_hat <- as.numeric(generator(matrix(rnorm(1000))))
plot(density(y_hat), bty = "L", main = "Density of fake data y_hat")
curve(dnorm(x, mu_hat, sd_hat), add = TRUE, col = "darkgrey")
rug(y_hat)
Conclusion
In this post, I taught a generator to generate \(\mathcal{N}(1, 3)\) data by making it compete against a discriminator. The parameters of the generator were very close to the true values: 0.95 \(\approx\) 1 and 3.12 \(\approx\) 3. In other words, without having specified any likelihood, I have approximated maximum likelihood estimation for this distribution! For other distributions and with fewer assumptions, both the generator and the discriminator would need to be extended. The generator would need to perform more complicated transformations to, say, generate uniformly distributed data. The discriminator would also need to be more flexible to create nonlinear or multivariate decision boundaries.
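To make the maximum likelihood claim concrete, here is a minimal sketch (assuming y, mu_hat, and sd_hat from above are still in the workspace) comparing the generator’s parameters with the closed-form normal maximum likelihood estimates:

# closed-form MLEs for a normal distribution: the sample mean and
# the (1/n, i.e. biased) standard deviation
mu_mle <- mean(y)
sd_mle <- sqrt(mean((y - mu_mle)^2))

# compare with the parameters recovered by the generator
round(c(mu_mle = mu_mle, mu_gan = mu_hat, sd_mle = sd_mle, sd_gan = sd_hat), 3)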
In creating this smallest GAN, I’ve learned the following things:
- The process requires quite some fine-tuning to the problem at hand. I started with much larger networks for the same problem, and they performed very poorly.
- The optimization needs tuning as well. In the end, I followed the two time-scale update rule, which made a huge difference for stability.
- The discriminator absolutely needs enough capacity to detect all the relevant features of the real distribution you want to approximate. This reminds me of approximate Bayesian computation, where you need to select the right summary statistics. If the discriminator lacks this capacity, the result will be mode collapse: the generator will just generate the mean of the distribution, and the discriminator will not be able to tell that this does not match the true distribution (illustrated below).
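As an illustration of that last point, here is a sketch (not part of the script above; the name weak_discriminator is just for illustration) of a discriminator with too little capacity: a plain logistic regression on \(y\). Its decision boundary is monotone in \(y\), so it can react to a shift in location but is largely blind to differences in spread:

# a deliberately under-powered discriminator: a single linear unit
# followed by a sigmoid, i.e. plain logistic regression on y
weak_discriminator <- nn_sequential(
  nn_linear(1, 1),
  nn_sigmoid()
)
# plugging this into the training loop above (in place of `discriminator`)
# would typically let the generator match the mean of y while leaving the
# weight (the implied sd) poorly identified: a one-dimensional analogue
# of mode collapse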