Creating New Anime Faces with Generative Adversarial Networks (GANs) using PyTorch

Sankalp Saoji
May 6, 2021


Generative adversarial networks (GANs), in simple terms, are two neural networks competing against each other, each trying to defeat the other. In the process, both networks learn and improve, and we can use them to create new, synthetic data that closely resembles the original data the networks were trained on.

GANs are being used for image generation, voice generation and video generation. They fall under generative modeling, an unsupervised learning task in ML that involves generating new content.

The generator generates a “fake” sample given a random vector/matrix, and the discriminator attempts to detect whether a given sample is “real” (picked from the training data) or “fake” (generated by the generator). Training happens in tandem: we train the discriminator on a batch, then train the generator on a batch, and repeat. This way both the generator and the discriminator get better at doing their jobs. GANs, however, can be notoriously difficult to train, and are extremely sensitive to hyperparameters, activation functions and regularization.
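To make this adversarial loop concrete before we dive into the full project, here is a minimal, self-contained sketch on 1-D toy data (my own illustration, not part of this article's pipeline): the generator learns to produce numbers drawn roughly from N(4, 0.5), purely by trying to fool the discriminator.

import torch
import torch.nn as nn

# Toy 1-D GAN: the "real" data are samples from N(4, 0.5)
G = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
D = nn.Sequential(nn.Linear(1, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(G.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(D.parameters(), lr=1e-3)
bce = nn.BCELoss()

for step in range(2000):
    real = torch.randn(64, 1) * 0.5 + 4.0
    fake = G(torch.randn(64, 8))
    # Discriminator step: push real towards 1, fake towards 0
    opt_d.zero_grad()
    loss_d = bce(D(real), torch.ones(64, 1)) + bce(D(fake.detach()), torch.zeros(64, 1))
    loss_d.backward()
    opt_d.step()
    # Generator step: make the discriminator output 1 for fakes
    opt_g.zero_grad()
    loss_g = bce(D(fake), torch.ones(64, 1))
    loss_g.backward()
    opt_g.step()

# The mean of generated samples should drift towards 4.0
print(G(torch.randn(1000, 8)).mean().item())

The exact same structure, with convolutional networks in place of these tiny linear ones, is what we will build for images below.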

In this tutorial, we’ll train a GAN to generate images of anime characters’ faces. We’ll use the Anime Face Dataset, which consists of over 63,000 cropped anime faces. There are no labels here as this is an unsupervised learning technique.

I used Google Colab for coding this project, so I will show you how it is done step by step.

Downloading the Data

We will download the data with the ‘opendatasets’ library. The dataset is located at ‘https://www.kaggle.com/splcher/animefacedataset’.

import opendatasets as od

dataset_url = 'https://www.kaggle.com/splcher/animefacedataset'
od.download(dataset_url)

You will need a Kaggle account to download this dataset, as the command above will ask for your Kaggle account ID and API key. You can find the API key in the account settings of your Kaggle account. Once you enter those, the dataset will be downloaded in Colab. The dataset has a single folder called images, which contains all 63,000+ images in JPG format.

Exploring the Data

import os

DATA_DIR = './animefacedataset'
print(os.listdir(DATA_DIR))
print(os.listdir(DATA_DIR + '/images')[:10])

You should get output like ['14791_2006.jpg', '15606_2006.jpg', '37864_2012.jpg', '61376_2018.jpg', '51585_2015.jpg', '1505_2001.jpg', '42640_2013.jpg', '54069_2016.jpg', '48489_2014.jpg', '388_2000.jpg'].

Importing Required Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
from tqdm.notebook import tqdm
from IPython.display import Image
%matplotlib inline

Creating the DataLoader

Let’s load this dataset using the ImageFolder class from torchvision. We will also resize and crop the images to 64x64 px, and normalize the pixel values with a mean & standard deviation of 0.5 for each channel. This will ensure that pixel values are in the range (-1, 1), which is more convenient for training the discriminator. We will also create a data loader to load the data in batches.

image_size = 64
batch_size = 128
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

train_ds = ImageFolder(DATA_DIR, transform=T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(*stats)]))

train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)
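As a quick sanity check (my own snippet, not from the original notebook): normalizing with mean 0.5 and standard deviation 0.5 computes (x - 0.5) / 0.5 = 2x - 1, which maps pixel values from [0, 1] into [-1, 1].

imgs, _ = next(iter(train_dl))               # one batch of normalized images
print(imgs.shape)                            # torch.Size([128, 3, 64, 64])
print(imgs.min().item(), imgs.max().item())  # should be close to -1.0 and 1.0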

Creating Helper Functions for our Work

Function for De-normalization:

def denorm(img_tensors):
    return img_tensors * stats[1][0] + stats[0][0]

Function for Showing the Images:

def show_images(images, nmax=64):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(denorm(images.detach()[:nmax]), nrow=8).permute(1, 2, 0))

Function for Showing the Batch:

def show_batch(dl, nmax=64):
    for images, _ in dl:
        show_images(images, nmax)
        break

To take a look at a batch of training images, we run:

show_batch(train_dl)

We can see that the batch looks like the image below.

Setting up the GPU

Below are the helper functions and classes we will use to set up the GPU.

def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')

def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

Based on where you're running this notebook, your default device could be a CPU (torch.device('cpu')) or a GPU (torch.device('cuda')).

device = get_default_device()
train_dl = DeviceDataLoader(train_dl, device)

After this, we are able to use the GPU.

Creating the Discriminator

The discriminator takes an image as input and tries to classify it as “real” or “generated”. In this sense, it’s like any other neural network. We’ll use a convolutional neural network (CNN), which outputs a single number for every image. We’ll use a stride of 2 to progressively reduce the size of the output feature map.
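For a conv layer, the output size is floor((in + 2 * padding - kernel_size) / stride) + 1, so a 4x4 kernel with stride 2 and padding 1 halves the feature map: (64 + 2 - 4) / 2 + 1 = 32. A quick check of this rule (my own snippet, not from the article):

conv = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
print(conv(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 64, 32, 32])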

We create the discriminator as follows.

discriminator = nn.Sequential(
    # in: 3 x 64 x 64
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 64 x 32 x 32

    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 128 x 16 x 16

    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 256 x 8 x 8

    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 512 x 4 x 4

    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # out: 1 x 1 x 1

    nn.Flatten(),
    nn.Sigmoid())

We’re using the Leaky ReLU activation for the discriminator. Unlike the regular ReLU function, Leaky ReLU allows a small gradient signal to pass for negative values. As a result, the gradients from the discriminator flow more strongly into the generator. Instead of passing a gradient (slope) of 0 for negative inputs in the backward pass, it passes a small non-zero gradient (here with slope 0.2). Just like any other binary classification model, the output of the discriminator is a single number between 0 and 1, which can be interpreted as the probability of the input image being real, i.e. picked from the original dataset.
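A tiny illustration of the difference between the two activations (an assumed example, not from the article):

x = torch.tensor([-1.0, 2.0])
print(F.relu(x))             # tensor([0., 2.]): the negative input is zeroed out
print(F.leaky_relu(x, 0.2))  # tensor([-0.2000, 2.0000]): a small signal survives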

Let’s move the discriminator model to the chosen device.

discriminator = to_device(discriminator, device)

Creating the Generator

The input to the generator is typically a vector or a matrix of random numbers (referred to as a latent tensor), which is used as a seed for generating an image. The generator will convert a latent tensor of shape (128, 1, 1) into an image tensor of shape 3 x 64 x 64. To achieve this, we'll use the ConvTranspose2d layer from PyTorch.
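ConvTranspose2d works like a convolution run in reverse: with no dilation or output padding, the output size is (in - 1) * stride - 2 * padding + kernel_size. The first layer below maps 1x1 to (1 - 1) * 1 - 0 + 4 = 4x4, and each subsequent layer doubles the size, e.g. (4 - 1) * 2 - 2 + 4 = 8, and so on up to 64. A quick check (my own snippet, not from the article):

up1 = nn.ConvTranspose2d(128, 512, kernel_size=4, stride=1, padding=0)
print(up1(torch.randn(1, 128, 1, 1)).shape)  # torch.Size([1, 512, 4, 4])
up2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
print(up2(torch.randn(1, 512, 4, 4)).shape)  # torch.Size([1, 256, 8, 8])

With that sizing rule in mind, here is the full generator.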

latent_size = 128

generator = nn.Sequential(
    # in: latent_size x 1 x 1
    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # out: 512 x 4 x 4

    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    # out: 256 x 8 x 8

    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # out: 128 x 16 x 16

    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    # out: 64 x 32 x 32

    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh()
    # out: 3 x 64 x 64
)

The ReLU activation is used in the generator, with the exception of the output layer, which uses the Tanh function. The DCGAN authors observed that using a bounded activation allowed the model to learn more quickly to saturate and cover the color space of the training distribution, and that within the discriminator the leaky rectified activation worked well, especially for higher-resolution modeling.

Note that since the outputs of the Tanh activation lie in the range [-1, 1], we have applied the same transformation to the images in the training dataset. Let's generate some outputs using the (so far untrained) generator and view them as images by denormalizing the output.

xb = torch.randn(batch_size, latent_size, 1, 1)  # random latent tensors
fake_images = generator(xb)
print(fake_images.shape)
show_images(fake_images)

As one might expect, the output from the generator is basically random noise, since we haven’t trained it yet.

Let’s move the generator to the chosen device.

generator = to_device(generator, device)

Training the Discriminator

Since the discriminator is a binary classification model, we can use the binary cross entropy loss function to quantify how well it is able to differentiate between real and generated images.
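As a quick illustration of how binary cross-entropy behaves (a made-up example, not from the article): confident, correct predictions cost almost nothing, while confident, wrong predictions are penalized heavily.

preds = torch.tensor([[0.9], [0.1]])    # discriminator outputs for two images
targets = torch.tensor([[1.0], [1.0]])  # both images are actually real
print(F.binary_cross_entropy(preds, targets))
# mean of -log(0.9) and -log(0.1), i.e. about 1.204; most of the loss
# comes from the second, confidently wrong prediction

Here is the full discriminator training step.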

def train_discriminator(real_images, opt_d):
    # Clear discriminator gradients
    opt_d.zero_grad()

    # Pass real images through discriminator
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Generate fake images
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Pass fake images through discriminator
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    # Update discriminator weights
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

Here are the steps involved in training the discriminator.

  • We expect the discriminator to output 1 if the image was picked from the real anime face dataset, and 0 if it was generated by the generator network.
  • We first pass a batch of real images through the discriminator and compute the loss, setting the target labels to 1.
  • Then we pass a batch of fake images (generated using the generator) through the discriminator and compute the loss, setting the target labels to 0.
  • Finally, we add the two losses and use the overall loss to perform gradient descent, adjusting the weights of the discriminator.

It’s important to note that we don’t change the weights of the generator model while training the discriminator (opt_d only updates discriminator.parameters()).
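One related detail worth knowing (a common variant, not used in this article's code): you can call .detach() on the fake images before passing them to the discriminator, which stops PyTorch from computing gradients for the generator at all during this step.

# Variant of the fake-image forward pass inside train_discriminator:
fake_preds = discriminator(fake_images.detach())
# .detach() cuts the autograd graph at the generator's output, saving memory
# and compute; the training outcome is the same either way, because
# opt_d.step() only ever updates the discriminator's parameters.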

Training the Generator

Since the outputs of the generator are images, it’s not obvious how we can train the generator. This is where we employ a rather elegant trick, which is to use the discriminator as a part of the loss function. Here’s how it works:

  • We generate a batch of images using the generator and pass them into the discriminator.
  • We calculate the loss by setting the target labels to 1 i.e. real. We do this because the generator’s objective is to “fool” the discriminator.
  • We use the loss to perform gradient descent i.e. change the weights of the generator, so it gets better at generating real-like images to “fool” the discriminator.

Here’s what this looks like in code.

def train_generator(opt_g):
    # Clear generator gradients
    opt_g.zero_grad()

    # Generate fake images
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Try to fool the discriminator
    preds = discriminator(fake_images)
    targets = torch.ones(batch_size, 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)

    # Update generator weights
    loss.backward()
    opt_g.step()
    return loss.item()

Now, let’s create a directory to save the images.

sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

Let’s also create a helper function to export the generated images.

def save_samples(index, latent_tensors, show=True):
    fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))

We’ll use a fixed set of input vectors to the generator to see how the individual generated images evolve over time as we train the model. Let’s save one set of images before we start training our model.

fixed_latent = torch.randn(64, latent_size, 1, 1, device=device)
save_samples(0, fixed_latent)

The generated images, shown below, are essentially noise at this point.

Let’s define a fit function to train the discriminator and generator in tandem for each batch of training data. We'll use the Adam optimizer with some custom parameters (betas) that are known to work well for GANs. We will also save some sample generated images at regular intervals for inspection.

def fit(epochs, lr, start_idx=1):
    torch.cuda.empty_cache()

    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    # Create optimizers
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_images, _ in tqdm(train_dl):
            # Train discriminator
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            # Train generator
            loss_g = train_generator(opt_g)

        # Record losses & scores
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, loss_g, loss_d, real_score, fake_score))

        # Save generated images
        save_samples(epoch+start_idx, fixed_latent, show=False)

    return losses_g, losses_d, real_scores, fake_scores

Training the Model

lr = 0.0002
epochs = 25

We choose a learning rate of 0.0002 and train for 25 epochs. You can try other combinations and see how the model performs for yourself.

losses_g, losses_d, real_scores, fake_scores = fit(epochs, lr)

At the last epoch, we get,

Epoch [25/25], loss_g: 2.4400, loss_d: 0.1189, real_score: 0.9247, fake_score: 0.0280 Saving generated-images-0025.png

Saving the Model

torch.save(generator.state_dict(), 'G.pth')
torch.save(discriminator.state_dict(), 'D.pth')
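To reuse the trained generator later (in a fresh session, for example), you can reload the saved weights and sample new faces. A sketch, assuming the generator architecture above has already been defined:

generator.load_state_dict(torch.load('G.pth', map_location=device))
generator.eval()
with torch.no_grad():
    z = torch.randn(64, latent_size, 1, 1, device=device)
    show_images(generator(z).cpu())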

Viewing the Progress

As we can see, the model gets better over time, and in the end we obtain previously unseen anime faces that look like the real thing!
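A simple way to inspect the progress is to display the image grids saved during training, using the Image class we imported earlier. For example, to view the samples from the final epoch:

Image('./generated/generated-images-0025.png')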

Plotting the Graphs

plt.plot(losses_d, '-')
plt.plot(losses_g, '-')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['Discriminator', 'Generator'])
plt.title('Losses');

plt.plot(real_scores, '-')
plt.plot(fake_scores, '-')
plt.xlabel('epoch')
plt.ylabel('score')
plt.legend(['Real', 'Fake'])
plt.title('Scores');

So, this is it! I hope this article gives you the motivation to try out different ideas with GANs.
