CycleGAN Architecture Overview
CycleGAN enables unpaired image-to-image translation using cycle-consistent adversarial networks. This approach learns mappings between domains without requiring paired training examples, making it suitable for style transfer applications like converting apples to oranges.
Dataset Preparation
The dataset consists of 996 apple and 1020 orange training images, with 266 apple and 248 orange test images. All images are resized to 256×256 pixels and preprocessed with random cropping, horizontal flipping, and normalization.
from mindspore.dataset import MindDataset
# One image pair per step.
batch_size = 1

# Read the pre-built MindRecord file and group its rows into batches.
dataset = MindDataset(dataset_files="apple2orange_train.mindrecord")
data_loader = dataset.batch(batch_size)

# Number of batches produced per pass over data_loader.
dataset_size = data_loader.get_dataset_size()
Generator Implementation
The generator uses a ResNet-based architecture with 9 residual blocks for 256×256 images. It employs convolutional layers with normalization and ReLU activation:
import mindspore.nn as nn
class ConvBlock(nn.Cell):
    """Convolution (or transposed convolution) + optional norm + optional activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel: Square kernel size.
        stride: Convolution stride.
        norm: When equal to ``'instance'`` a non-affine normalization layer is
            appended after the convolution; any other value skips it.
        transpose: Use ``nn.Conv2dTranspose`` instead of ``nn.Conv2d``.
        use_relu: Append the activation layer. Pass ``False`` for a block
            whose output is added to a skip connection.
    """

    def __init__(self, in_channels, out_channels, kernel=4, stride=2,
                 norm='instance', transpose=False, use_relu=True):
        super().__init__()
        padding = (kernel - 1) // 2
        conv_cls = nn.Conv2dTranspose if transpose else nn.Conv2d
        # MindSpore only honours an explicit `padding` when pad_mode='pad';
        # the default pad_mode rejects a non-zero padding value.
        conv = conv_cls(in_channels, out_channels, kernel, stride,
                        pad_mode='pad', padding=padding)
        layers = [conv]
        if norm == 'instance':
            # NOTE(review): this is BatchNorm2d without affine parameters, not
            # nn.InstanceNorm2d; kept as in the original -- confirm intent.
            layers.append(nn.BatchNorm2d(out_channels, affine=False))
        if use_relu:
            # Stride-1 blocks use plain ReLU, strided blocks use LeakyReLU(0.2).
            layers.append(nn.ReLU() if stride == 1 else nn.LeakyReLU(0.2))
        self.model = nn.SequentialCell(layers)

    def construct(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        return self.model(x)
class ResidualModule(nn.Cell):
    """Two stacked 3x3, stride-1 ConvBlocks wrapped in a skip connection.

    Args:
        channels: Channel count used for both the input and the output of
            each ConvBlock.
    """

    def __init__(self, channels):
        super().__init__()
        first_conv = ConvBlock(channels, channels, kernel=3, stride=1)
        # The second block is built with use_relu=False.
        second_conv = ConvBlock(channels, channels, kernel=3, stride=1, use_relu=False)
        self.block = nn.SequentialCell([first_conv, second_conv])

    def construct(self, x):
        """Return the input plus the output of the two-block stack."""
        residual = self.block(x)
        return x + residual
class ImageTranslator(nn.Cell):
    """ResNet-style generator: stem, two downsampling blocks, residual
    blocks, two upsampling blocks, and a 7x7 output convolution with tanh.

    Args:
        channels: Channel count produced by the stem; doubled by each
            downsampling block and halved by each upsampling block.
        blocks: Number of ResidualModule instances in the bottleneck.
    """

    def __init__(self, channels=64, blocks=9):
        super().__init__()
        self.initial = ConvBlock(3, channels, kernel=7, stride=1)
        self.downsample = nn.SequentialCell([
            ConvBlock(channels, channels * 2),
            ConvBlock(channels * 2, channels * 4),
        ])
        self.res_blocks = nn.SequentialCell(
            [ResidualModule(channels * 4) for _ in range(blocks)]
        )
        self.upsample = nn.SequentialCell([
            ConvBlock(channels * 4, channels * 2, transpose=True),
            ConvBlock(channels * 2, channels, transpose=True),
        ])
        # nn.Conv2d's keyword is `kernel_size`, and an explicit padding is
        # only accepted together with pad_mode='pad'.
        self.final = nn.Conv2d(channels, 3, kernel_size=7,
                               pad_mode='pad', padding=3)
        # Built once here instead of on every forward pass.
        self.tanh = nn.Tanh()

    def construct(self, x):
        """Translate a batch of images; the output is passed through tanh."""
        x = self.initial(x)
        x = self.downsample(x)
        x = self.res_blocks(x)
        x = self.upsample(x)
        return self.tanh(self.final(x))
Discriminator Implementation
The PatchGAN discriminator uses convolutional layers to classify image patches:
class Discriminator(nn.Cell):
    """Convolutional discriminator that outputs a one-channel score map.

    Args:
        in_channels: Channels of the input image.
        base_channels: Channels produced by the first convolution.
        layers: Total number of strided stages; the first is a bare
            convolution + LeakyReLU, the remaining ``layers - 1`` are
            ConvBlocks whose width is ``min(2**i, 8) * base_channels``.
    """

    def __init__(self, in_channels=3, base_channels=64, layers=3):
        super().__init__()
        # pad_mode='pad' is required for MindSpore to accept padding=1.
        sequence = [
            nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2,
                      pad_mode='pad', padding=1),
            nn.LeakyReLU(0.2),
        ]
        current_channels = base_channels
        for i in range(1, layers):
            # Double the width at each stage, capped at 8x the base width.
            next_channels = min(2 ** i, 8) * base_channels
            sequence.append(ConvBlock(current_channels, next_channels))
            current_channels = next_channels
        # Final projection to a single-channel score map.
        sequence.append(nn.Conv2d(current_channels, 1, kernel_size=4,
                                  pad_mode='pad', padding=1))
        self.model = nn.SequentialCell(sequence)

    def construct(self, x):
        """Return the score map for a batch of images."""
        return self.model(x)
Training Setup
The training uses adversarial loss combined with cycle consistency loss:
import mindspore.ops as ops
# Creation order is kept as-is so weight initialisation order is unchanged.
gen_apple = ImageTranslator()
gen_orange = ImageTranslator()
disc_apple = Discriminator()
disc_orange = Discriminator()

# Two instances of the same class produce identical parameter names, and a
# MindSpore optimizer requires every parameter name it receives to be unique.
# Prefix each network's parameters before sharing an optimizer between them.
for _prefix, _net in (("gen_apple.", gen_apple), ("gen_orange.", gen_orange),
                      ("disc_apple.", disc_apple), ("disc_orange.", disc_orange)):
    _net.update_parameters_name(_prefix)

# One optimizer for both generators, one for both discriminators.
gen_optim = nn.Adam(gen_apple.trainable_params() + gen_orange.trainable_params(),
                    learning_rate=0.0002)
disc_optim = nn.Adam(disc_apple.trainable_params() + disc_orange.trainable_params(),
                     learning_rate=0.0002)
def adversarial_loss(pred, target):
    """Mean-squared error between ``pred`` and a constant label map.

    Args:
        pred: Discriminator output.
        target: Scalar label; the label map is ``ones_like(pred) * target``.

    Returns:
        The MSE loss between ``pred`` and the label map.
    """
    label_map = ops.ones_like(pred) * target
    criterion = nn.MSELoss()
    return criterion(pred, label_map)
def cycle_consistency_loss(original, reconstructed, weight=10.0):
    """Weighted L1 distance between an image and its reconstruction.

    Args:
        original: Reference image batch.
        reconstructed: Image batch to compare against ``original``.
        weight: Multiplier applied to the L1 loss. Defaults to 10.0, the
            value that was previously hard-coded.

    Returns:
        ``nn.L1Loss()(original, reconstructed) * weight``.
    """
    return nn.L1Loss()(original, reconstructed) * weight
Training Process
The training alternates between generator and discriminator updates:
from mindspore import value_and_grad


def generator_forward(apple_img, orange_img):
    """Compute the combined generator loss for one unpaired batch.

    Returns:
        A tuple ``(total_gen_loss, fake_apple, fake_orange)``. Only the first
        element is differentiated; the fakes are returned so the
        discriminator step can reuse them.
    """
    # gen_orange maps apple -> orange, gen_apple maps orange -> apple.
    fake_orange = gen_orange(apple_img)
    fake_apple = gen_apple(orange_img)

    # Cycle consistency: translate back with the opposite generator.
    cycled_apple = gen_apple(fake_orange)
    cycled_orange = gen_orange(fake_apple)

    # Identity preservation: a generator fed an image that is already in its
    # output domain should leave it unchanged.
    id_apple = gen_apple(apple_img)
    id_orange = gen_orange(orange_img)

    gen_loss_orange = adversarial_loss(disc_orange(fake_orange), True)
    gen_loss_apple = adversarial_loss(disc_apple(fake_apple), True)
    cycle_loss_apple = cycle_consistency_loss(apple_img, cycled_apple)
    cycle_loss_orange = cycle_consistency_loss(orange_img, cycled_orange)
    id_loss_apple = cycle_consistency_loss(apple_img, id_apple) * 0.5
    id_loss_orange = cycle_consistency_loss(orange_img, id_orange) * 0.5

    total_gen_loss = (gen_loss_orange + gen_loss_apple +
                      cycle_loss_apple + cycle_loss_orange +
                      id_loss_apple + id_loss_orange)
    return total_gen_loss, fake_apple, fake_orange


def discriminator_forward(apple_img, orange_img, fake_apple, fake_orange):
    """Compute the averaged discriminator loss on real and generated images."""
    # Generated images are treated as constants in the discriminator step.
    fake_apple = ops.stop_gradient(fake_apple)
    fake_orange = ops.stop_gradient(fake_orange)

    disc_loss_apple = (adversarial_loss(disc_apple(apple_img), True) +
                       adversarial_loss(disc_apple(fake_apple), False))
    disc_loss_orange = (adversarial_loss(disc_orange(orange_img), True) +
                        adversarial_loss(disc_orange(fake_orange), False))
    return (disc_loss_apple + disc_loss_orange) * 0.5


# MindSpore has no loss.backward()/optimizer.step(): gradients come from a
# function transform and are applied by calling the optimizer with them.
gen_grad_fn = value_and_grad(generator_forward, None, gen_optim.parameters,
                             has_aux=True)
disc_grad_fn = value_and_grad(discriminator_forward, None, disc_optim.parameters)

# Put every network in training mode before optimising.
for _net in (gen_apple, gen_orange, disc_apple, disc_orange):
    _net.set_train()

# total_epochs must be defined before this loop runs.
for epoch in range(total_epochs):
    # create_dict_iterator() yields rows keyed by column name; iterating the
    # dataset object directly does not.
    for images in data_loader.create_dict_iterator():
        apple_img, orange_img = images["image_A"], images["image_B"]

        # Update generators
        (total_gen_loss, fake_apple, fake_orange), gen_grads = gen_grad_fn(
            apple_img, orange_img)
        gen_optim(gen_grads)

        # Update discriminators
        total_disc_loss, disc_grads = disc_grad_fn(
            apple_img, orange_img, fake_apple, fake_orange)
        disc_optim(disc_grads)
Inference Implementation
After training, we can apply style transfer to new images:
def apply_style_transfer(model, image_path):
    """Run a trained generator on one image file.

    Args:
        model: Callable generator taking a ``(1, 3, 256, 256)`` float tensor
            scaled to [-1, 1].
        image_path: Path of the image to translate.

    Returns:
        A ``uint8`` numpy array in height x width x channel layout.
    """
    # Function-scope imports: none of these names are imported at file level.
    import numpy as np
    from PIL import Image
    from mindspore import Tensor
    import mindspore.dataset.vision as vision

    with Image.open(image_path) as handle:
        pixels = np.array(handle.convert('RGB'))

    # Vision transforms are plain callables, not nn.Cell objects, so they are
    # applied one after another instead of through nn.SequentialCell.
    preprocess = (
        vision.Resize((256, 256)),
        # Pixels are in 0..255, so 127.5 / 127.5 maps them to [-1, 1]; this
        # is the inverse of the `* 127.5 + 127.5` applied to the output.
        vision.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
        vision.HWC2CHW(),
    )
    processed = pixels
    for step in preprocess:
        processed = step(processed)

    # Add the batch axis and hand the network a float32 tensor.
    batch = Tensor(np.expand_dims(processed, 0).astype(np.float32))
    output = model(batch)

    restored = output.squeeze().asnumpy().transpose(1, 2, 0) * 127.5 + 127.5
    # Clip before the cast so rounding overshoot cannot wrap around in uint8.
    return np.clip(restored, 0, 255).astype(np.uint8)