Introduction to Generative Adversarial Networks
Generative Adversarial Networks (GANs) are a class of machine learning frameworks designed to generate new data instances that resemble the training data. Introduced by Ian Goodfellow and colleagues in 2014, the architecture consists of two distinct neural networks that compete against each other in a two-player, game-theoretic setup:
- The Generator ($G$): This network takes a random noise vector (latent code) as input and attempts to produce synthetic data samples that look indistinguishable from real data.
- The Discriminator ($D$): This network acts as a binary classifier. It receives either real data from the training set or fake data from the Generator and tries to determine the source of the input.
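The training objective is a minimax game: the Generator tries to fool the Discriminator, while the Discriminator tries to correctly classify real versus fake samples. Ideally, this leads to a Nash equilibrium in which the Generator's samples follow the same distribution as the real data, so the Discriminator can do no better than output 0.5 for every input (50% accuracy, equivalent to random guessing).
In mathematical terms, $D(x)$ is the probability that $x$ is real, and the Generator $G(z)$ maps noise $z$, drawn from a standard normal distribution $p_z$, to the data space. The two networks play the minimax game

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))].$$

$D$ maximizes $V$, i.e., the log-probability of correctly classifying real and fake samples, while $G$ minimizes $\log(1 - D(G(z)))$. In practice (and in the training code of this tutorial), $G$ instead maximizes $\log D(G(z))$, which is equivalent to minimizing the BCE loss of its samples against the label "real"; this non-saturating variant gives stronger gradients early in training.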
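To make the objective concrete, the pure-Python sketch below (no MindSpore needed) estimates the value function $V(D, G) = \mathbb{E}_x[\log D(x)] + \mathbb{E}_z[\log(1 - D(G(z)))]$ from a handful of discriminator outputs. It also uses the standard result that, for a fixed Generator, the best Discriminator is $D^*(x) = p_{\text{data}}(x) / (p_{\text{data}}(x) + p_g(x))$; when $p_g = p_{\text{data}}$ this gives $D^* = 1/2$ and $V = -\log 4$. The helper names `value_function` and `optimal_discriminator` are illustrative only, not part of any library.

```python
import math

def value_function(d_real, d_fake):
    """Sample estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator outputs D(x) on real samples
    d_fake: discriminator outputs D(G(z)) on generated samples
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

def optimal_discriminator(p_data, p_g):
    """Best discriminator output at a point x, for a fixed generator."""
    return p_data / (p_data + p_g)

# At equilibrium p_g(x) == p_data(x), so D* = 0.5 and V = -log 4
d_star = optimal_discriminator(0.3, 0.3)
print(d_star)                                       # 0.5
print(value_function([d_star] * 4, [d_star] * 4))   # about -1.386 (= -log 4)

# A discriminator that separates real from fake pushes V up toward 0
print(value_function([0.9, 0.95], [0.1, 0.05]))     # about -0.157
```

The Discriminator's job is to push $V$ up toward 0; the Generator's job is to drag it back down to $-\log 4$, where its samples are indistinguishable from data.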
Environment Setup
To implement this solution, we use Python 3.9 and MindSpore 2.2.14. Install the framework, together with the download and matplotlib packages used by the code below:

```shell
pip install mindspore==2.2.14
pip install download matplotlib
```
Data Preparation
We utilize the MNIST dataset, which contains 60,000 training images and 10,000 test images of handwritten digits (28x28 pixels, grayscale).
Downloading and Loading Data
The following code downloads the dataset and builds the data pipeline. Each image is normalized to [-1, 1] to match the Generator's Tanh output range, paired with a freshly sampled latent noise vector, and batched for efficient training.
```python
import numpy as np
import mindspore.dataset as ds
from download import download

# Download and extract MNIST into ./MNIST_Data/{train,test}
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)

# Configuration
batch_sz = 64
noise_dim = 100  # Dimension of the latent vector

def create_dataset(dataset_path):
    # Load the raw MNIST data (uint8 images of shape (28, 28, 1))
    dataset = ds.MnistDataset(dataset_dir=dataset_path)

    def process_img(image):
        # Normalize pixel values from [0, 255] to [-1, 1]
        image = (image.astype("float32") - 127.5) / 127.5
        # Draw a fresh latent vector for the Generator input
        noise = np.random.normal(size=noise_dim).astype("float32")
        return image, noise

    # Map the "image" column to two output columns, then drop the unused "label" column
    dataset = dataset.map(operations=process_img, input_columns=["image"],
                          output_columns=["image", "noise"])
    dataset = dataset.project(["image", "noise"])
    dataset = dataset.batch(batch_sz, drop_remainder=True)
    return dataset

train_ds = create_dataset("./MNIST_Data/train")
print(f"Iterations per epoch: {train_ds.get_dataset_size()}")  # 60000 // 64 = 937
```
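The normalization and batching arithmetic can be checked without MindSpore. The sketch below uses two hypothetical helpers, `normalize` and `denormalize`, to show that (x - 127.5) / 127.5 sends pixel 0 to -1 and pixel 255 to 1, that x * 0.5 + 0.5 inverts this for display, and that batching 60,000 training images with batch size 64 and drop_remainder=True yields 937 full batches, discarding 32 images per epoch.

```python
def normalize(pixel):
    """Map a pixel value in [0, 255] to [-1, 1], the same formula process_img uses."""
    return (pixel - 127.5) / 127.5

def denormalize(value):
    """Map a value in [-1, 1] (e.g. a Tanh output) back to [0, 1] for display."""
    return value * 0.5 + 0.5

num_train = 60000                        # MNIST training images
batch_sz = 64
steps_per_epoch = num_train // batch_sz  # full batches only (drop_remainder=True)
dropped = num_train % batch_sz           # images left over and skipped each epoch

print(normalize(0), normalize(127.5), normalize(255))  # -1.0 0.0 1.0
print(denormalize(-1.0), denormalize(1.0))             # 0.0 1.0
print(steps_per_epoch, dropped)                        # 937 32
```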
Visualizing Data
We can visualize a batch of the training data to confirm our preprocessing pipeline.
```python
import matplotlib.pyplot as plt

# Fetch one batch as NumPy arrays
data_iter = next(train_ds.create_dict_iterator(output_numpy=True))
fig = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for i in range(cols * rows):
    img = data_iter["image"][i]
    fig.add_subplot(rows, cols, i + 1)
    plt.axis("off")
    # Rescale from [-1, 1] back to [0, 1] for display
    plt.imshow(img.squeeze() * 0.5 + 0.5, cmap="gray", vmin=0, vmax=1)
plt.show()
```
Fixed Latent Vectors
To monitor the Generator's progress, we use a fixed set of noise vectors throughout the training process.
```python
from mindspore import Tensor
import mindspore.common.dtype as mstype

# 25 latent vectors (one 5x5 image grid), reused after every epoch
np.random.seed(42)
fixed_noise = Tensor(np.random.normal(0, 1, (25, noise_dim)), mstype.float32)
```
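Fixed latent vectors work because a seeded generator reproduces exactly the same draws. The sketch below illustrates this with Python's standard-library `random` module rather than NumPy (an assumption made to keep it dependency-free); `sample_latent_batch` is a hypothetical helper. It uses a private `random.Random` instance, so re-seeding does not disturb any other code that relies on the global random state; NumPy's analogous option is `np.random.default_rng(42)`.

```python
import random

def sample_latent_batch(seed, batch=25, dim=100):
    """Return `batch` latent vectors of length `dim`, drawn from N(0, 1)."""
    rng = random.Random(seed)  # private RNG, independent of the global random state
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(batch)]

first = sample_latent_batch(42)
second = sample_latent_batch(42)
other = sample_latent_batch(7)
print(first == second)  # True: same seed, identical vectors
print(first == other)   # False: a different seed gives different vectors
```

Because the inputs never change, any difference between the per-epoch sample grids can be attributed to the Generator's weights.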
Model Architecture
For the MNIST dataset, a fully connected architecture is sufficient. We avoid complex convolutional layers for this demonstration to focus on the core GAN mechanics.
Generator Network
The Generator maps a 100-dimensional noise vector to a 784-dimensional vector (28 x 28), which is reshaped into a 1 x 28 x 28 image. It uses ReLU activations, Batch Normalization, and a Tanh output activation so that pixel values lie in [-1, 1], matching the normalized training data.
```python
import mindspore.nn as nn
import mindspore.ops as ops

class DigitGenerator(nn.Cell):
    def __init__(self, latent_dim):
        super().__init__()
        self.model = nn.SequentialCell(
            nn.Dense(latent_dim, 128),
            nn.ReLU(),
            nn.Dense(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dense(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dense(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dense(1024, 28 * 28),
            nn.Tanh()  # Output range [-1, 1]
        )

    def construct(self, x):
        output = self.model(x)
        # Reshape flat 784-vectors into (batch, 1, 28, 28) images
        return ops.reshape(output, (-1, 1, 28, 28))

# Initialize the Generator; prefix its parameter names to avoid clashes with the Discriminator
generator = DigitGenerator(noise_dim)
generator.update_parameters_name("g_")
```
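A quick way to gauge the Generator's size is to count its trainable parameters by hand. The sketch below assumes each Dense layer holds an n_in x n_out weight matrix plus n_out biases, and that each BatchNorm1d layer contributes only its trainable scale and shift (gamma, beta), since its moving mean and variance are not trained. Under those assumptions, summing the element counts of generator.trainable_params() should give the same total; `dense_params` and `batchnorm_params` are hypothetical helpers.

```python
def dense_params(n_in, n_out):
    """Weights (n_in x n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

def batchnorm_params(num_features):
    """Trainable scale (gamma) and shift (beta); moving statistics are not trainable."""
    return 2 * num_features

# Layer widths of DigitGenerator: latent -> 128 -> 256 -> 512 -> 1024 -> 784
gen_widths = [100, 128, 256, 512, 1024, 28 * 28]
dense_total = sum(dense_params(a, b) for a, b in zip(gen_widths, gen_widths[1:]))

# BatchNorm1d follows the 256-, 512- and 1024-unit layers (not the first 128-unit one)
bn_total = sum(batchnorm_params(n) for n in (256, 512, 1024))

print(dense_total, bn_total, dense_total + bn_total)  # 1506448 3584 1510032
```

About 1.5 million parameters, more than half of them in the final 1024 -> 784 projection.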
Discriminator Network
The Discriminator takes a 28x28 image and outputs a single scalar representing the probability that the image is real. It uses LeakyReLU activations.
```python
class DigitDiscriminator(nn.Cell):
    def __init__(self):
        super().__init__()
        self.model = nn.SequentialCell(
            nn.Dense(28 * 28, 512),
            nn.LeakyReLU(alpha=0.2),
            nn.Dense(512, 256),
            nn.LeakyReLU(alpha=0.2),
            nn.Dense(256, 1),
            nn.Sigmoid()  # Output range [0, 1]
        )

    def construct(self, x):
        # Flatten each 28x28 image (whatever its channel layout) into a 784-vector
        x_flat = ops.reshape(x, (-1, 28 * 28))
        return self.model(x_flat)

# Initialize the Discriminator with its own parameter-name prefix
discriminator = DigitDiscriminator()
discriminator.update_parameters_name("d_")
```
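The sketch below applies the same Dense-layer counting assumption (weights plus biases) to the Discriminator and spells out the LeakyReLU activation with slope 0.2. Unlike ReLU, LeakyReLU keeps a small non-zero slope for negative inputs, which is commonly cited as helping gradients flow back through the Discriminator to the Generator. The `leaky_relu` function here is an illustrative scalar version, not MindSpore's implementation.

```python
def leaky_relu(x, alpha=0.2):
    """Identity for x >= 0, slope alpha for x < 0 (the gradient is never exactly zero)."""
    return x if x >= 0 else alpha * x

# Layer widths of DigitDiscriminator: 784 -> 512 -> 256 -> 1
disc_widths = [28 * 28, 512, 256, 1]
disc_params = sum(a * b + b for a, b in zip(disc_widths, disc_widths[1:]))

print(disc_params)                          # 533505
print(leaky_relu(-5.0), leaky_relu(3.0))    # -1.0 3.0
```

With roughly 0.53 million parameters, the Discriminator is about a third the size of the Generator.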
Training Configuration
We use Binary Cross Entropy (BCE) loss and the Adam optimizer (learning rate 0.0002, β1 = 0.5, β2 = 0.999) for both networks.
```python
# Loss Function
loss_fn = nn.BCELoss(reduction='mean')

# Optimizers
lr = 0.0002
opt_g = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
opt_d = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)

# Prefix optimizer state names (moment buffers) so the two optimizers do not clash
opt_g.update_parameters_name('adam_g')
opt_d.update_parameters_name('adam_d')
```
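The sketch below spells out, in pure Python, what mean-reduced BCE computes and what a single Adam update does to one scalar parameter. The Adam update is the textbook form (bias-corrected first and second moments); framework implementations, including MindSpore's nn.Adam, may differ slightly in where eps enters. `bce_mean` and `adam_step` are hypothetical helpers. The choice of β1 = 0.5 with lr = 0.0002 follows the DCGAN recommendations, which reported that the default β1 = 0.9 led to oscillating, unstable GAN training.

```python
import math

def bce_mean(probs, labels, eps=1e-12):
    """Mean binary cross-entropy over a batch (like reduction='mean')."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clip so log() never sees 0
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)

def adam_step(param, grad, m, v, t, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):
    """One textbook Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias corrections
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# An undecided discriminator (p = 0.5) scores BCE = log 2 for either label
print(bce_mean([0.5, 0.5], [1, 0]))  # about 0.693

# The first Adam step moves a parameter by about lr, as long as |grad| >> eps
new_param, m, v = adam_step(1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(1.0 - new_param)  # about 0.0002
```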
Training Loop
The training process involves alternating between updating the Discriminator and the Generator.
- Update Discriminator: Train on real images (label 1) and fake images (label 0).
- Update Generator: Generate fake images and attempt to get the Discriminator to classify them as real (label 1).
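Labeling generated images as "real" in the Generator step means the Generator minimizes $-\log D(G(z))$ rather than the $\log(1 - D(G(z)))$ term of the minimax objective. The original GAN paper proposed this non-saturating variant because the minimax term yields vanishing gradients when the Discriminator confidently rejects generated samples, which is typical early in training. The pure-Python sketch below compares the two gradients with respect to the Discriminator's pre-sigmoid logit $a$ (so $D = \sigma(a)$); the function names are illustrative.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_grad(a):
    """d/da of the minimax Generator loss log(1 - sigmoid(a))."""
    return -sigmoid(a)

def non_saturating_grad(a):
    """d/da of the non-saturating Generator loss -log(sigmoid(a)), i.e. BCE with label 1."""
    return -(1.0 - sigmoid(a))

# The Discriminator assigns a generated sample only a 1% chance of being real
d_fake = 0.01
a = math.log(d_fake / (1.0 - d_fake))  # logit, so that sigmoid(a) == d_fake

print(saturating_grad(a))        # about -0.01: almost no learning signal
print(non_saturating_grad(a))    # about -0.99
print(non_saturating_grad(a) / saturating_grad(a))  # about 99
```

Both gradients point the same way (increase $D(G(z))$), but the non-saturating loss gives roughly 99 times the signal in this regime.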
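```python
import os
import time
import mindspore as ms

def discriminator_forward(real_imgs, latent_z):
    # Real images should be classified as 1, generated images as 0
    real_out = discriminator(real_imgs)
    fake_out = discriminator(generator(latent_z))
    real_loss = loss_fn(real_out, ops.ones_like(real_out))
    fake_loss = loss_fn(fake_out, ops.zeros_like(fake_out))
    return real_loss + fake_loss

def generator_forward(latent_z):
    # The Generator wants its samples classified as real (label 1)
    fake_out = discriminator(generator(latent_z))
    return loss_fn(fake_out, ops.ones_like(fake_out))

# Differentiate each loss only with respect to its own network's parameters
grad_d_fn = ms.value_and_grad(discriminator_forward, None, discriminator.trainable_params())
grad_g_fn = ms.value_and_grad(generator_forward, None, generator.trainable_params())

def train_step(real_imgs, latent_z):
    # 1) Update the Discriminator
    d_loss, d_grads = grad_d_fn(real_imgs, latent_z)
    opt_d(d_grads)
    # 2) Update the Generator against the freshly updated Discriminator
    g_loss, g_grads = grad_g_fn(latent_z)
    opt_g(g_grads)
    return d_loss, g_loss
```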
```python
# Helper to save a 5x5 grid of generated images
def save_generated_images(gen_imgs, epoch_idx):
    gen_imgs = gen_imgs.asnumpy()
    fig = plt.figure(figsize=(3, 3))
    for i in range(25):
        plt.subplot(5, 5, i + 1)
        # Rescale Tanh output from [-1, 1] to [0, 1]
        plt.imshow(gen_imgs[i, 0, :, :] * 0.5 + 0.5, cmap="gray", vmin=0, vmax=1)
        plt.axis("off")
    plt.savefig(f"./results/epoch_{epoch_idx}.png")
    plt.close(fig)

# Output directory for sample grids and checkpoints
os.makedirs("./results", exist_ok=True)

# Training parameters
epochs = 12
steps_per_epoch = train_ds.get_dataset_size()
loss_history_g, loss_history_d = [], []

for epoch in range(epochs):
    start_time = time.time()
    generator.set_train()
    discriminator.set_train()
    epoch_d, epoch_g = [], []
    for idx, (real_imgs, latent_z) in enumerate(train_ds.create_tuple_iterator()):
        d_loss, g_loss = train_step(real_imgs, latent_z)
        epoch_d.append(float(d_loss.asnumpy()))
        epoch_g.append(float(g_loss.asnumpy()))
        if idx % 100 == 0:
            print(f"Epoch: [{epoch + 1}/{epochs}], Step: [{idx}/{steps_per_epoch}], "
                  f"D Loss: {epoch_d[-1]:.4f}, G Loss: {epoch_g[-1]:.4f}")
    # Record the mean loss over the epoch (less noisy than the last batch alone)
    loss_history_d.append(np.mean(epoch_d))
    loss_history_g.append(np.mean(epoch_g))
    # Sample from the fixed noise in inference mode (BatchNorm uses moving statistics)
    generator.set_train(False)
    fake_imgs = generator(fixed_noise)
    save_generated_images(fake_imgs, epoch)
    # Save checkpoints every epoch
    ms.save_checkpoint(generator, f"./results/generator_{epoch}.ckpt")
    ms.save_checkpoint(discriminator, f"./results/discriminator_{epoch}.ckpt")
    print(f"Time for epoch {epoch}: {time.time() - start_time:.2f}s")
```
Results and Visualization
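Plotting the per-epoch losses shows how the two networks' objectives evolve during training. Because the networks compete, GAN losses usually oscillate rather than decrease steadily, so these curves are a diagnostic (for example, a Discriminator loss collapsing toward zero signals that it is overpowering the Generator) rather than a direct measure of sample quality.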
```python
plt.figure(figsize=(6, 4))
plt.plot(loss_history_g, label="Generator Loss")
plt.plot(loss_history_d, label="Discriminator Loss")
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
```
Inference with Trained Model
Finally, we can load a trained checkpoint to generate new digits.
```python
# Load weights (uncomment to use a saved checkpoint)
# ckpt_path = "./results/generator_11.ckpt"
# param_dict = ms.load_checkpoint(ckpt_path)
# ms.load_param_into_net(generator, param_dict)

# Use BatchNorm moving statistics for inference
generator.set_train(False)

# Generate new random noise
inference_noise = Tensor(np.random.normal(0, 1, (25, noise_dim)), mstype.float32)
# (25, 1, 28, 28) -> (25, 28, 28, 1) NumPy array
generated_images = generator(inference_noise).transpose(0, 2, 3, 1).asnumpy()

# Visualize, rescaling from [-1, 1] to [0, 1]
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):
    fig.add_subplot(5, 5, i + 1)
    plt.axis("off")
    plt.imshow(generated_images[i].squeeze() * 0.5 + 0.5, cmap="gray", vmin=0, vmax=1)
plt.show()
```