CycleGAN Implementation for Unpaired Image Translation

CycleGAN Architecture Overview

CycleGAN enables unpaired image-to-image translation using cycle-consistent adversarial networks. This approach learns mappings between domains without requiring paired training examples, making it suitable for style transfer applications like converting apples to oranges.

Dataset Preparation

The dataset consists of 996 apple and 1020 orange training images, with 266 apple and 248 orange test images. All images are resized to 256×256 pixels and preprocessed with random cropping, horizontal flipping, and normalization.

from mindspore.dataset import MindDataset

# One sample per batch for the training pipeline.
batch_size = 1

# Read the training records from the MindRecord file and group them into batches.
dataset = MindDataset("apple2orange_train.mindrecord")
data_loader = dataset.batch(batch_size=batch_size)

# Size reported by the batched pipeline.
dataset_size = data_loader.get_dataset_size()

Generator Implementation

The generator uses a ResNet-based architecture with 9 residual blocks for 256×256 images. It employs convolutional layers with normalization and ReLU activation:

import mindspore.nn as nn

class ConvBlock(nn.Cell):
    """Convolution (or transposed convolution) + optional norm + optional activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel: Square kernel size.
        stride: Convolution stride.
        norm: ``'instance'`` appends a non-affine instance normalization layer;
            any other value appends no normalization.
        transpose: Use ``nn.Conv2dTranspose`` instead of ``nn.Conv2d``.
        use_relu: When False, no activation layer is appended (used for the
            last convolution of a residual block).
    """

    def __init__(self, in_channels, out_channels, kernel=4, stride=2,
                 norm='instance', transpose=False, use_relu=True):
        super().__init__()
        padding = (kernel - 1) // 2
        conv_cls = nn.Conv2dTranspose if transpose else nn.Conv2d
        # MindSpore only honours an explicit `padding` value with
        # pad_mode='pad'; the default pad_mode='same' rejects padding != 0.
        conv = conv_cls(in_channels, out_channels, kernel, stride,
                        pad_mode='pad', padding=padding)

        layers = [conv]
        if norm == 'instance':
            # Real instance normalization: BatchNorm2d only matches it when
            # training with a batch size of exactly 1.
            layers.append(nn.InstanceNorm2d(out_channels, affine=False))

        if use_relu:
            # Stride-1 blocks get ReLU, strided blocks get LeakyReLU(0.2).
            layers.append(nn.ReLU() if stride == 1 else nn.LeakyReLU(0.2))
        self.model = nn.SequentialCell(layers)

    def construct(self, x):
        """Apply the convolution, normalization and activation in sequence."""
        return self.model(x)

class ResidualModule(nn.Cell):
    """Residual block: two 3x3, stride-1 ConvBlocks wrapped in a skip connection.

    Args:
        channels: Channel count, identical for the block's input and output.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel=3, stride=1)
        # NOTE(review): relies on ConvBlock accepting a `use_relu` keyword;
        # confirm the ConvBlock constructor supports it.
        second = ConvBlock(channels, channels, kernel=3, stride=1, use_relu=False)
        self.block = nn.SequentialCell([first, second])

    def construct(self, x):
        """Return the input plus the output of the two stacked ConvBlocks."""
        residual = self.block(x)
        return x + residual

class ImageTranslator(nn.Cell):
    """ResNet-style generator: stem, 2 downsampling blocks, residual blocks,
    2 upsampling blocks, and a 7x7 output convolution followed by tanh.

    Args:
        channels: Channel count produced by the stem convolution.
        blocks: Number of residual blocks in the bottleneck.
    """

    def __init__(self, channels=64, blocks=9):
        super().__init__()
        self.initial = ConvBlock(3, channels, kernel=7, stride=1)
        self.downsample = nn.SequentialCell([
            ConvBlock(channels, channels*2),
            ConvBlock(channels*2, channels*4)
        ])
        self.res_blocks = nn.SequentialCell(
            [ResidualModule(channels*4) for _ in range(blocks)]
        )
        self.upsample = nn.SequentialCell([
            ConvBlock(channels*4, channels*2, transpose=True),
            ConvBlock(channels*2, channels, transpose=True)
        ])
        # nn.Conv2d takes `kernel_size` (not `kernel`), and an explicit
        # padding value is only accepted together with pad_mode='pad'.
        self.final = nn.Conv2d(channels, 3, kernel_size=7,
                               pad_mode='pad', padding=3)
        # Build the activation once here instead of on every forward pass.
        self.tanh = nn.Tanh()

    def construct(self, x):
        """Translate an image batch; the output is squashed by tanh."""
        x = self.initial(x)
        x = self.downsample(x)
        x = self.res_blocks(x)
        x = self.upsample(x)
        return self.tanh(self.final(x))

Discriminator Implementation

The PatchGAN discriminator uses convolutional layers to classify image patches:

class Discriminator(nn.Cell):
    """PatchGAN-style discriminator producing a one-channel score map.

    Args:
        in_channels: Number of channels of the input image.
        base_channels: Channel count after the first convolution.
        layers: Depth control; ``layers - 1`` ConvBlocks follow the first
            convolution, with the channel multiplier doubling up to a cap of 8.
    """

    def __init__(self, in_channels=3, base_channels=64, layers=3):
        super().__init__()
        # pad_mode='pad' is required for MindSpore to accept an explicit
        # padding value; the default pad_mode='same' rejects padding != 0.
        sequence = [
            nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2,
                      pad_mode='pad', padding=1),
            nn.LeakyReLU(0.2)
        ]

        current_channels = base_channels
        for i in range(1, layers):
            next_channels = min(2**i, 8) * base_channels
            sequence.append(ConvBlock(current_channels, next_channels))
            current_channels = next_channels

        # Final projection to a single-channel map of per-patch scores.
        sequence.append(nn.Conv2d(current_channels, 1, kernel_size=4,
                                  pad_mode='pad', padding=1))
        self.model = nn.SequentialCell(sequence)

    def construct(self, x):
        """Return the score map for the input image batch."""
        return self.model(x)

Training Setup

The training uses adversarial loss combined with cycle consistency loss:

import mindspore.ops as ops

# One generator and one discriminator per image domain.
gen_apple = ImageTranslator()
gen_orange = ImageTranslator()
disc_apple = Discriminator()
disc_orange = Discriminator()

# Two instances of the same class carry identical parameter names, and a
# MindSpore optimizer rejects duplicate names. Give each network a unique
# prefix before its parameters are handed to a shared optimizer.
gen_apple.update_parameters_name('gen_apple.')
gen_orange.update_parameters_name('gen_orange.')
disc_apple.update_parameters_name('disc_apple.')
disc_orange.update_parameters_name('disc_orange.')

# Both generators share one Adam optimizer; both discriminators share another.
gen_optim = nn.Adam([*gen_apple.trainable_params(), *gen_orange.trainable_params()],
                   learning_rate=0.0002)
disc_optim = nn.Adam([*disc_apple.trainable_params(), *disc_orange.trainable_params()],
                    learning_rate=0.0002)

def adversarial_loss(pred, target):
    """Return the mean-squared error between `pred` and a constant target.

    The target tensor has the same shape as `pred` and is filled with
    `target` (callers pass True for "real" and False for "fake").
    """
    ones = ops.ones_like(pred)
    expected = ones * target
    mse = nn.MSELoss()
    return mse(pred, expected)

def cycle_consistency_loss(original, reconstructed):
    """Return the L1 loss between the two inputs, scaled by a weight of 10."""
    l1 = nn.L1Loss()
    distance = l1(original, reconstructed)
    return distance * 10.0

Training Process

The training alternates between generator and discriminator updates:

import mindspore as ms


def generator_forward(apple_img, orange_img):
    """Compute the combined generator loss for one batch.

    Returns the total loss followed by the two translated images, so the
    discriminator step can reuse them without another forward pass.
    """
    # Translate each domain into the other.
    fake_orange = gen_orange(apple_img)
    fake_apple = gen_apple(orange_img)

    # Cycle consistency: the round trip must go back through the *opposite*
    # generator (apple -> orange -> apple, orange -> apple -> orange).
    cycled_apple = gen_apple(fake_orange)
    cycled_orange = gen_orange(fake_apple)

    # Identity preservation: a generator fed an image that is already in its
    # output domain should leave it unchanged.
    id_apple = gen_apple(apple_img)
    id_orange = gen_orange(orange_img)

    gen_loss_orange = adversarial_loss(disc_orange(fake_orange), True)
    gen_loss_apple = adversarial_loss(disc_apple(fake_apple), True)
    cycle_loss_apple = cycle_consistency_loss(apple_img, cycled_apple)
    cycle_loss_orange = cycle_consistency_loss(orange_img, cycled_orange)
    id_loss_apple = cycle_consistency_loss(apple_img, id_apple) * 0.5
    id_loss_orange = cycle_consistency_loss(orange_img, id_orange) * 0.5

    total_gen_loss = (gen_loss_orange + gen_loss_apple +
                      cycle_loss_apple + cycle_loss_orange +
                      id_loss_apple + id_loss_orange)
    return total_gen_loss, fake_apple, fake_orange


def discriminator_forward(apple_img, orange_img, fake_apple, fake_orange):
    """Compute the averaged discriminator loss on real and generated images."""
    # Generated images are treated as constants in the discriminator step.
    fake_apple = ops.stop_gradient(fake_apple)
    fake_orange = ops.stop_gradient(fake_orange)

    disc_loss_apple = adversarial_loss(disc_apple(apple_img), True) + \
                      adversarial_loss(disc_apple(fake_apple), False)
    disc_loss_orange = adversarial_loss(disc_orange(orange_img), True) + \
                       adversarial_loss(disc_orange(fake_orange), False)
    return (disc_loss_apple + disc_loss_orange) * 0.5


# MindSpore has no loss.backward()/optimizer.step(): gradients come from a
# function transform and are applied by calling the optimizer with them.
# Each transform differentiates only with respect to its optimizer's weights.
gen_grad_fn = ms.value_and_grad(generator_forward, None,
                                gen_optim.parameters, has_aux=True)
disc_grad_fn = ms.value_and_grad(discriminator_forward, None,
                                 disc_optim.parameters)

for net in (gen_apple, gen_orange, disc_apple, disc_orange):
    net.set_train(True)

# NOTE(review): `total_epochs` is not defined in the visible code; it must be
# set before this loop runs.
for epoch in range(total_epochs):
    # A dict iterator is needed to access the columns by name.
    for images in data_loader.create_dict_iterator():
        apple_img, orange_img = images["image_A"], images["image_B"]

        # Update generators
        (total_gen_loss, fake_apple, fake_orange), gen_grads = gen_grad_fn(
            apple_img, orange_img)
        gen_optim(gen_grads)

        # Update discriminators
        total_disc_loss, disc_grads = disc_grad_fn(
            apple_img, orange_img, fake_apple, fake_orange)
        disc_optim(disc_grads)

Inference Implementation

After training, we can apply style transfer to new images:

def apply_style_transfer(model, image_path):
    """Translate one image file with a trained generator.

    Args:
        model: Generator cell; called with a float32 tensor built from the
            preprocessed image with a leading batch axis added.
        image_path: Path of the image to translate.

    Returns:
        The translated image as a ``uint8`` NumPy array in HWC layout.
    """
    import mindspore as ms

    img = Image.open(image_path).convert('RGB')
    # vision transforms are not nn.Cell objects, so they cannot be placed in
    # nn.SequentialCell; apply them one after another instead.
    # Pixels are in [0, 255] here, so mean/std of 127.5 maps them to [-1, 1],
    # matching the inverse scaling applied to the output below.
    pipeline = [
        vision.Resize((256, 256)),
        vision.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
        vision.HWC2CHW(),
    ]
    processed = np.asarray(img)
    for step in pipeline:
        processed = step(processed)

    # The transforms yield a NumPy array; wrap it in a Tensor and add the
    # batch axis before calling the model.
    batch = ms.Tensor(processed, ms.float32).expand_dims(0)
    output = model(batch)

    # Drop the batch axis, go back to HWC, and undo the normalization.
    restored = output.asnumpy()[0].transpose(1, 2, 0) * 127.5 + 127.5
    # Clip before the cast so rounding error cannot wrap around in uint8.
    return np.clip(restored, 0, 255).astype(np.uint8)

Tags: CycleGAN mindspore Image Translation Generative Adversarial Networks Computer Vision

Posted on Sat, 20 Jun 2026 17:38:35 +0000 by mrjam