Model Architecture
CycleGAN (Cycle-Consistent Generative Adversarial Network) implements the cycle-consistent adversarial networks from the paper "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks". It learns two mappings between a source domain X and a target domain Y, a generator G: X→Y and a reverse generator F: Y→X, without requiring paired training examples.
The primary application involves domain adaptation, commonly understood as image style transfer. While previous models like Pix2Pix required paired training data, CycleGAN operates with unpaired datasets from two domains, making it suitable for unsupervised image translation scenarios where finding corresponding pairs across different styles is impractical.
Dataset Preparation
The dataset's images come from ImageNet (the full release spans 17 data packages); this case focuses on apple-to-orange transformations. Images are standardized to 256×256 pixels. The training split contains 996 apple images and 1020 orange images, with 266 apple and 248 orange images reserved for testing.
Data preprocessing involves random cropping, horizontal flipping, and normalization. Preprocessed results are converted to MindRecord format to streamline data handling operations.
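The shipped MindRecord file already contains the preprocessed tensors, so the pipeline below is only a sketch of what an equivalent preprocessing stage might look like when starting from raw images; the ./raw_apples directory and the resize-then-crop sizes are assumptions following the common CycleGAN recipe, not part of this dataset's actual conversion script.
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
raw_images = ds.ImageFolderDataset("./raw_apples", decode=True)  # hypothetical raw image folder
augmentations = [
    vision.Resize((286, 286)),                 # enlarge slightly...
    vision.RandomCrop((256, 256)),             # ...then crop back to 256x256 for variety
    vision.RandomHorizontalFlip(prob=0.5),
    vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3),  # scale to [-1, 1]
    vision.HWC2CHW(),
]
raw_images = raw_images.map(operations=augmentations, input_columns=["image"])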
# 1. Download dataset
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"
download(url, ".", kind="zip", replace=True)
# 2. Load dataset
from mindspore.dataset import MindDataset
data_path = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
training_data = MindDataset(dataset_files=data_path)
print("Dataset size: ", training_data.get_dataset_size())
batch_size = 1
processed_dataset = training_data.batch(batch_size)
dataset_size = processed_dataset.get_dataset_size()
# 3. Visualize dataset
import numpy as np
import matplotlib.pyplot as plt
mean_val = 0.5 * 255
std_val = 0.5 * 255
plt.figure(figsize=(12, 5), dpi=60)
for idx, batch in enumerate(processed_dataset.create_dict_iterator()):
    if idx < 5:
        apple_images = batch["image_A"].asnumpy()
        orange_images = batch["image_B"].asnumpy()
        plt.subplot(2, 5, idx + 1)
        apple_processed = (apple_images[0] * std_val + mean_val).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(apple_processed)
        plt.axis("off")
        plt.subplot(2, 5, idx + 6)
        orange_processed = (orange_images[0] * std_val + mean_val).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(orange_processed)
        plt.axis("off")
    else:
        break
plt.show()
Generator Construction
The generator architecture follows ResNet principles. For 128×128 inputs, six residual blocks are used; for 256×256 inputs, nine residual blocks are employed. The constructor argument block_count controls the number of residual blocks.
The architecture is an encoder-decoder: downsampling convolutions, a stack of residual blocks whose identity shortcuts ease training, and transposed-convolution upsampling back to the input resolution:
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal
initializer_weights = Normal(sigma=0.02)
class ConvolutionNormRelu(nn.Cell):
    def __init__(self, input_ch, output_ch, kernel=4, stride=2, leak_alpha=0.2, norm_type='instance',
                 padding_mode='CONSTANT', activate=True, pad_amount=None, transposed=False):
        super(ConvolutionNormRelu, self).__init__()
        normalizer = nn.BatchNorm2d(output_ch)
        if norm_type == 'instance':
            normalizer = nn.BatchNorm2d(output_ch, affine=False)
        bias_enabled = (norm_type == 'instance')
        if pad_amount is None:
            pad_amount = (kernel - 1) // 2
        if padding_mode == 'CONSTANT':
            if transposed:
                convolution = nn.Conv2dTranspose(input_ch, output_ch, kernel, stride, pad_mode='same',
                                                 has_bias=bias_enabled, weight_init=initializer_weights)
            else:
                convolution = nn.Conv2d(input_ch, output_ch, kernel, stride, pad_mode='pad',
                                        has_bias=bias_enabled, padding=pad_amount, weight_init=initializer_weights)
            layers = [convolution, normalizer]
        else:
            padding_values = ((0, 0), (0, 0), (pad_amount, pad_amount), (pad_amount, pad_amount))
            pad_layer = nn.Pad(paddings=padding_values, mode=padding_mode)
            if transposed:
                convolution = nn.Conv2dTranspose(input_ch, output_ch, kernel, stride, pad_mode='pad',
                                                 has_bias=bias_enabled, weight_init=initializer_weights)
            else:
                convolution = nn.Conv2d(input_ch, output_ch, kernel, stride, pad_mode='pad',
                                        has_bias=bias_enabled, weight_init=initializer_weights)
            layers = [pad_layer, convolution, normalizer]
        if activate:
            activation = nn.ReLU()
            if leak_alpha > 0:
                activation = nn.LeakyReLU(leak_alpha)
            layers.append(activation)
        self.layer_stack = nn.SequentialCell(layers)

    def construct(self, x):
        result = self.layer_stack(x)
        return result
class ResidualUnit(nn.Cell):
    def __init__(self, dimensions, norm_style='instance', apply_dropout=False, pad_style="CONSTANT"):
        super(ResidualUnit, self).__init__()
        self.conv_block1 = ConvolutionNormRelu(dimensions, dimensions, 3, 1, 0, norm_style, pad_style)
        # No activation after the second convolution; the residual addition follows directly
        self.conv_block2 = ConvolutionNormRelu(dimensions, dimensions, 3, 1, 0, norm_style, pad_style,
                                               activate=False)
        self.dropout_enabled = apply_dropout
        if apply_dropout:
            self.dropout_layer = nn.Dropout(p=0.5)

    def construct(self, x):
        intermediate = self.conv_block1(x)
        if self.dropout_enabled:
            intermediate = self.dropout_layer(intermediate)
        output = self.conv_block2(intermediate)
        return x + output
class ResNetTranslator(nn.Cell):
    def __init__(self, input_ch=3, base_ch=64, block_count=9, leak_alpha=0.2, norm_style='instance', use_dropout=False,
                 pad_style="CONSTANT"):
        super(ResNetTranslator, self).__init__()
        self.encoder_input = ConvolutionNormRelu(input_ch, base_ch, 7, 1, leak_alpha, norm_style, padding_mode=pad_style)
        self.downsample1 = ConvolutionNormRelu(base_ch, base_ch * 2, 3, 2, leak_alpha, norm_style)
        self.downsample2 = ConvolutionNormRelu(base_ch * 2, base_ch * 4, 3, 2, leak_alpha, norm_style)
        # Build block_count distinct residual blocks (a single shared instance would tie their weights)
        residual_blocks = [ResidualUnit(base_ch * 4, norm_style, apply_dropout=use_dropout, pad_style=pad_style)
                           for _ in range(block_count)]
        self.residual_stack = nn.SequentialCell(residual_blocks)
        self.upsample2 = ConvolutionNormRelu(base_ch * 4, base_ch * 2, 3, 2, leak_alpha, norm_style, transposed=True)
        self.upsample1 = ConvolutionNormRelu(base_ch * 2, base_ch, 3, 2, leak_alpha, norm_style, transposed=True)
        if pad_style == "CONSTANT":
            self.decoder_output = nn.Conv2d(base_ch, 3, kernel_size=7, stride=1, pad_mode='pad',
                                            padding=3, weight_init=initializer_weights)
        else:
            pad_layer = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_style)
            conv_layer = nn.Conv2d(base_ch, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=initializer_weights)
            self.decoder_output = nn.SequentialCell([pad_layer, conv_layer])

    def construct(self, x):
        x = self.encoder_input(x)
        x = self.downsample1(x)
        x = self.downsample2(x)
        x = self.residual_stack(x)
        x = self.upsample2(x)
        x = self.upsample1(x)
        output = self.decoder_output(x)
        return ops.tanh(output)
# Initialize generators
translator_x_to_y = ResNetTranslator()
translator_x_to_y.update_parameters_name('translator_x_to_y.')
translator_y_to_x = ResNetTranslator()
translator_y_to_x.update_parameters_name('translator_y_to_x.')
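As a quick sanity check (a sketch, assuming the cells above have been executed), a 256×256 batch should come back at the same resolution with values in [-1, 1] from the final tanh; for 128×128 inputs you would construct ResNetTranslator(block_count=6) instead. The dummy_input name is just an illustration, not part of the tutorial pipeline.
import numpy as np
import mindspore as ms
from mindspore import Tensor
# Hypothetical dummy batch in NCHW layout
dummy_input = Tensor(np.random.randn(1, 3, 256, 256), ms.float32)
dummy_output = translator_x_to_y(dummy_input)
print(dummy_output.shape)  # expected: (1, 3, 256, 256)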
Discriminator Construction
The discriminator is a binary classifier that estimates how likely an image is real rather than generated. It follows the PatchGAN design with a 70×70 receptive field: stacked Conv2d, normalization, and LeakyReLU layers ending in a final convolution that emits a grid of per-patch predictions (no sigmoid is applied here, since training uses a least-squares adversarial loss).
# Define discriminator
class ImageClassifier(nn.Cell):
    def __init__(self, input_ch=3, base_ch=64, layer_count=3, leak_alpha=0.2, norm_style='instance'):
        super(ImageClassifier, self).__init__()
        kernel_size = 4
        layers = [nn.Conv2d(input_ch, base_ch, kernel_size, 2, pad_mode='pad', padding=1, weight_init=initializer_weights),
                  nn.LeakyReLU(leak_alpha)]
        multiplier = base_ch
        for i in range(1, layer_count):
            prev_multiplier = multiplier
            multiplier = min(2 ** i, 8) * base_ch
            layers.append(ConvolutionNormRelu(prev_multiplier, multiplier, kernel_size, 2, leak_alpha, norm_style, pad_amount=1))
        prev_multiplier = multiplier
        multiplier = min(2 ** layer_count, 8) * base_ch
        layers.append(ConvolutionNormRelu(prev_multiplier, multiplier, kernel_size, 1, leak_alpha, norm_style, pad_amount=1))
        # Final convolution maps the features to a 1-channel patch prediction map
        layers.append(nn.Conv2d(multiplier, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=initializer_weights))
        self.feature_extractor = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.feature_extractor(x)
        return output
# Initialize discriminators
discriminator_domain_x = ImageClassifier()
discriminator_domain_x.update_parameters_name('discriminator_domain_x.')
discriminator_domain_y = ImageClassifier()
discriminator_domain_y.update_parameters_name('discriminator_domain_y.')
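The patch design can be verified directly (a sketch reusing the dummy-input imports from the generator check above): with the default layer_count=3, a 256×256 input yields roughly a 30×30 grid of scores, each judging one receptive-field patch.
dummy_image = Tensor(np.random.randn(1, 3, 256, 256), ms.float32)
patch_scores = discriminator_domain_x(dummy_image)
print(patch_scores.shape)  # expected: (1, 1, 30, 30)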
Optimizers and Loss Functions
Each of the four networks (two generators, two discriminators) receives its own Adam optimizer, so generator and discriminator updates can be applied independently.
For generator G and its corresponding discriminator D_Y, the adversarial loss function is defined as:
L_GAN(G, D_Y, X, Y) = E[y~p_data(y)][log D_Y(y)] + E[x~p_data(x)][log(1 - D_Y(G(x)))]
The generator attempts to create images G(x) similar to domain Y, while D_Y distinguishes between translated samples G(x) and real samples y. The optimization goal is min_G max_D_Y L_GAN(G, D_Y, X, Y).
Adversarial loss alone cannot guarantee proper input-output mapping. To constrain the mapping function space, cycle consistency is enforced. For every image x in domain X, the transformation cycle should return x: x → G(x) → F(G(x)) ≈ x. Similarly for domain Y: y → F(y) → G(F(y)) ≈ y.
The cycle consistency loss function is:
L_cyc(G, F) = E[x~p_data(x)][||F(G(x)) - x||_1] + E[y~p_data(y)][||G(F(y)) - y||_1]
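Putting the pieces together, the full objective combines both adversarial terms with the weighted cycle term, as in the original paper:
L(G, F, D_X, D_Y) = L_GAN(G, D_Y, X, Y) + L_GAN(F, D_X, Y, X) + λ·L_cyc(G, F)
where λ sets the relative importance of cycle consistency (λ = 10 in the code below, via weight_a and weight_b). The implementation additionally adds identity-mapping terms scaled by identity_weight.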
# Setup optimizers
optimizer_gen_x_to_y = nn.Adam(translator_x_to_y.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_gen_y_to_x = nn.Adam(translator_y_to_x.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_disc_x = nn.Adam(discriminator_domain_x.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_disc_y = nn.Adam(discriminator_domain_y.trainable_params(), learning_rate=0.0002, beta1=0.5)
# Loss functions
mse_loss = nn.MSELoss(reduction='mean')
l1_distance = nn.L1Loss(reduction='mean')

def adversarial_loss(prediction, truth_label):
    # Broadcast the boolean label over the prediction map: True -> 1.0, False -> 0.0
    target_tensor = ops.ones_like(prediction) * truth_label
    calculated_loss = mse_loss(prediction, target_tensor)
    return calculated_loss
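A tiny usage sketch with made-up prediction values (reusing the imports from the sanity checks above; the 30×30 map size matches the discriminator check): adversarial_loss compares the whole patch map against the broadcast label, so True trains predictions toward 1.0 and False toward 0.0.
from mindspore import dtype
patch_prediction = Tensor(np.full((1, 1, 30, 30), 0.3), ms.float32)
print(adversarial_loss(patch_prediction, Tensor(True, dtype.bool_)))   # mean (0.3 - 1)^2 = 0.49
print(adversarial_loss(patch_prediction, Tensor(False, dtype.bool_)))  # mean 0.3^2 = 0.09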
Forward Pass Implementation
To reduce model oscillation, discriminators are updated using a history of generated images rather than only the most recent ones: an image buffer stores up to 50 previously generated images and randomly mixes them into each discriminator update.
import random
import mindspore as ms
from mindspore import Tensor
from mindspore import dtype
# Forward computation
def translate_images(source_img, target_img):
    synthetic_source = translator_y_to_x(target_img)            # F(y)
    synthetic_target = translator_x_to_y(source_img)            # G(x)
    reconstructed_source = translator_y_to_x(synthetic_target)  # F(G(x)), should recover x
    reconstructed_target = translator_x_to_y(synthetic_source)  # G(F(y)), should recover y
    identity_source = translator_y_to_x(source_img)             # identity mapping F(x)
    identity_target = translator_x_to_y(target_img)             # identity mapping G(y)
    return synthetic_source, synthetic_target, reconstructed_source, reconstructed_target, identity_source, identity_target

weight_a = 10.0        # cycle-consistency weight for domain X
weight_b = 10.0        # cycle-consistency weight for domain Y
identity_weight = 0.5  # identity loss scaled to half the cycle weight

def generator_computation(src_img, tgt_img):
    true_tensor = Tensor(True, dtype.bool_)
    fake_src, fake_tgt, rec_src, rec_tgt, id_src, id_tgt = translate_images(src_img, tgt_img)
    loss_gen_src = adversarial_loss(discriminator_domain_y(fake_tgt), true_tensor)
    loss_gen_tgt = adversarial_loss(discriminator_domain_x(fake_src), true_tensor)
    loss_cycle_src = l1_distance(rec_src, src_img) * weight_a
    loss_cycle_tgt = l1_distance(rec_tgt, tgt_img) * weight_b
    loss_identity_src = l1_distance(id_src, src_img) * weight_a * identity_weight
    loss_identity_tgt = l1_distance(id_tgt, tgt_img) * weight_b * identity_weight
    total_loss_gen = loss_gen_src + loss_gen_tgt + loss_cycle_src + loss_cycle_tgt + loss_identity_src + loss_identity_tgt
    return fake_src, fake_tgt, total_loss_gen, loss_gen_src, loss_gen_tgt, loss_cycle_src, loss_cycle_tgt, loss_identity_src, loss_identity_tgt

def generator_gradient_computation(src_img, tgt_img):
    # Scalar loss used by value_and_grad below
    _, _, total_loss, _, _, _, _, _, _ = generator_computation(src_img, tgt_img)
    return total_loss

def discriminator_computation(src_img, tgt_img, gen_src, gen_tgt):
    false_tensor = Tensor(False, dtype.bool_)
    true_tensor = Tensor(True, dtype.bool_)
    disc_fake_src = discriminator_domain_x(gen_src)
    disc_real_src = discriminator_domain_x(src_img)
    disc_fake_tgt = discriminator_domain_y(gen_tgt)
    disc_real_tgt = discriminator_domain_y(tgt_img)
    loss_disc_src = adversarial_loss(disc_fake_src, false_tensor) + adversarial_loss(disc_real_src, true_tensor)
    loss_disc_tgt = adversarial_loss(disc_fake_tgt, false_tensor) + adversarial_loss(disc_real_tgt, true_tensor)
    combined_disc_loss = (loss_disc_src + loss_disc_tgt) * 0.5
    return combined_disc_loss

def discriminator_x_computation(real_src, generated_src):
    false_tensor = Tensor(False, dtype.bool_)
    true_tensor = Tensor(True, dtype.bool_)
    disc_gen_src = discriminator_domain_x(generated_src)
    disc_real_src = discriminator_domain_x(real_src)
    loss_disc_src = adversarial_loss(disc_gen_src, false_tensor) + adversarial_loss(disc_real_src, true_tensor)
    return loss_disc_src

def discriminator_y_computation(real_tgt, generated_tgt):
    false_tensor = Tensor(False, dtype.bool_)
    true_tensor = Tensor(True, dtype.bool_)
    disc_gen_tgt = discriminator_domain_y(generated_tgt)
    disc_real_tgt = discriminator_domain_y(real_tgt)
    loss_disc_tgt = adversarial_loss(disc_gen_tgt, false_tensor) + adversarial_loss(disc_real_tgt, true_tensor)
    return loss_disc_tgt
# Image buffer: keeps up to 50 previously generated images across calls
buffer_capacity = 50
stored_images = []

def history_buffer(images):
    if isinstance(images, Tensor):
        images = images.asnumpy()
    result_images = []
    for image in images:
        if len(stored_images) < buffer_capacity:
            # Pool not yet full: store the image and return it unchanged
            stored_images.append(image)
            result_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                # Swap the new image with a randomly chosen buffered one
                random_index = random.randint(0, buffer_capacity - 1)
                temp = stored_images[random_index].copy()
                stored_images[random_index] = image
                result_images.append(temp)
            else:
                result_images.append(image)
    output = Tensor(result_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output
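A behavior sketch with dummy data: while the pool is filling, images pass straight through; once 50 are stored, each call has a 50% chance of returning an older image in place of the new one, which is what decorrelates consecutive discriminator updates. The dummy_fake name is illustrative only.
dummy_fake = Tensor(np.random.randn(1, 3, 256, 256), ms.float32)
mixed_batch = history_buffer(dummy_fake)
print(mixed_batch.shape)  # (1, 3, 256, 256); contents may be current or buffered images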
Gradient Computation and Backpropagation
from mindspore import value_and_grad
# Gradient computation methods
gradient_gen_x = value_and_grad(generator_gradient_computation, None, translator_x_to_y.trainable_params())
gradient_gen_y = value_and_grad(generator_gradient_computation, None, translator_y_to_x.trainable_params())
gradient_disc_x = value_and_grad(discriminator_x_computation, None, discriminator_domain_x.trainable_params())
gradient_disc_y = value_and_grad(discriminator_y_computation, None, discriminator_domain_y.trainable_params())
# Generator training step
def train_generator_step(img_src, img_tgt):
    # Freeze discriminators while updating the generators
    discriminator_domain_x.set_grad(False)
    discriminator_domain_y.set_grad(False)
    fake_src, fake_tgt, gen_total_loss, gen_loss_x, gen_loss_y, cyc_loss_x, cyc_loss_y, id_loss_x, id_loss_y = generator_computation(img_src, img_tgt)
    _, gradients_gen_x = gradient_gen_x(img_src, img_tgt)
    _, gradients_gen_y = gradient_gen_y(img_src, img_tgt)
    optimizer_gen_x_to_y(gradients_gen_x)
    optimizer_gen_y_to_x(gradients_gen_y)
    return fake_src, fake_tgt, gen_total_loss, gen_loss_x, gen_loss_y, cyc_loss_x, cyc_loss_y, id_loss_x, id_loss_y

# Discriminator training step
def train_discriminator_step(img_src, img_tgt, fake_src, fake_tgt):
    discriminator_domain_x.set_grad(True)
    discriminator_domain_y.set_grad(True)
    loss_disc_x, grads_disc_x = gradient_disc_x(img_src, fake_src)
    loss_disc_y, grads_disc_y = gradient_disc_y(img_tgt, fake_tgt)
    combined_loss = (loss_disc_x + loss_disc_y) * 0.5
    optimizer_disc_x(grads_disc_x)
    optimizer_disc_y(grads_disc_y)
    return combined_loss
Model Training
Training alternates between discriminator and generator updates. The approach uses least squares loss instead of negative log-likelihood objectives.
- Discriminator training: learn to tell real from generated images by minimizing E[y~p_data(y)][(D(y)-1)^2] + E[x~p_data(x)][D(G(x))^2] (the code halves the sum over both discriminators)
- Generator training: minimize E[x~p_data(x)][(D(G(x))-1)^2] so that generated images are scored as real (see the worked numbers after this list)
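As a worked illustration with made-up scores: if a discriminator averages 0.9 on a real image and 0.3 on a generated one, its loss for that pair is (0.9 - 1)^2 + 0.3^2 = 0.10, while the generator's adversarial term for the same fake is (0.3 - 1)^2 = 0.49, so the gradient pushes the generator to raise that score toward 1.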
import os
import time
import random
import numpy as np
from PIL import Image
from mindspore import Tensor, save_checkpoint
from mindspore import dtype
# Training configuration
num_epochs = 1
checkpoint_interval = 80
save_frequency = 1
checkpoint_directory = './train_ckpt_outputs/'
print('Starting training!')
for epoch in range(num_epochs):
    generator_losses = []
    discriminator_losses = []
    epoch_start_time = time.time()
    for step, batch_data in enumerate(processed_dataset.create_dict_iterator()):
        step_start_time = time.time()
        source_image = batch_data["image_A"]
        target_image = batch_data["image_B"]
        generator_result = train_generator_step(source_image, target_image)
        generated_src = generator_result[0]
        generated_tgt = generator_result[1]
        discriminator_loss = train_discriminator_step(source_image, target_image,
                                                      history_buffer(generated_src), history_buffer(generated_tgt))
        step_disc_loss = float(discriminator_loss.asnumpy())
        step_duration = time.time() - step_start_time
        processed_results = []
        for item in generator_result[2:]:
            processed_results.append(float(item.asnumpy()))
        generator_losses.append(processed_results[0])
        discriminator_losses.append(step_disc_loss)
        if step % checkpoint_interval == 0:
            print(f"Epoch:[{int(epoch + 1):>3d}/{int(num_epochs):>3d}], "
                  f"step:[{int(step):>4d}/{int(dataset_size):>4d}], "
                  f"time:{step_duration:>3f}s,\n"
                  f"loss_g:{processed_results[0]:.2f}, loss_d:{step_disc_loss:.2f}, "
                  f"loss_g_a: {processed_results[1]:.2f}, loss_g_b: {processed_results[2]:.2f}, "
                  f"loss_c_a: {processed_results[3]:.2f}, loss_c_b: {processed_results[4]:.2f}, "
                  f"loss_idt_a: {processed_results[5]:.2f}, loss_idt_b: {processed_results[6]:.2f}")
    epoch_duration = time.time() - epoch_start_time
    average_step_time = epoch_duration / dataset_size
    avg_disc_loss, avg_gen_loss = sum(discriminator_losses) / dataset_size, sum(generator_losses) / dataset_size
    print(f"Epoch:[{int(epoch + 1):>3d}/{int(num_epochs):>3d}], "
          f"epoch time:{epoch_duration:.2f}s, per step time:{average_step_time:.2f}, "
          f"mean_g_loss:{avg_gen_loss:.2f}, mean_d_loss:{avg_disc_loss:.2f}")
    if epoch % save_frequency == 0:
        os.makedirs(checkpoint_directory, exist_ok=True)
        save_checkpoint(translator_x_to_y, os.path.join(checkpoint_directory, f"g_x_{epoch}.ckpt"))
        save_checkpoint(translator_y_to_x, os.path.join(checkpoint_directory, f"g_y_{epoch}.ckpt"))
        save_checkpoint(discriminator_domain_x, os.path.join(checkpoint_directory, f"d_x_{epoch}.ckpt"))
        save_checkpoint(discriminator_domain_y, os.path.join(checkpoint_directory, f"d_y_{epoch}.ckpt"))
print('Training completed!')
Model Inference
Style transfer is performed by loading trained generator parameters. Results display original images in the first row and corresponding generated images in the second row.
%%time
import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net
# Load checkpoint files
def restore_weights(model, checkpoint_path):
    params = load_checkpoint(checkpoint_path)
    load_param_into_net(model, params)
source_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
target_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'
restore_weights(translator_x_to_y, source_ckpt)
restore_weights(translator_y_to_x, target_ckpt)
# Perform inference
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def inference_evaluation(data_directory, model, offset):
    def load_images():
        for filename in os.listdir(data_directory):
            file_path = os.path.join(data_directory, filename)
            image = Image.open(file_path).convert('RGB')
            yield image, filename

    dataset = ds.GeneratorDataset(load_images, column_names=["image", "image_name"])
    transforms = [vision.Resize((256, 256)),
                  vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3),
                  vision.HWC2CHW()]
    dataset = dataset.map(operations=transforms, input_columns=["image"])
    dataset = dataset.batch(1)
    for i, batch in enumerate(dataset.create_dict_iterator()):
        original = batch["image"]
        transformed = model(original)
        transformed = (transformed[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        original = (original[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        fig.add_subplot(2, 8, i + 1 + offset)
        plt.axis("off")
        plt.imshow(original.asnumpy())
        fig.add_subplot(2, 8, i + 9 + offset)
        plt.axis("off")
        plt.imshow(transformed.asnumpy())
inference_evaluation('./CycleGAN_apple2orange/predict/apple', translator_x_to_y, 0)
inference_evaluation('./CycleGAN_apple2orange/predict/orange', translator_y_to_x, 4)
plt.show()