Implementing ResNet50 for Image Classification on CIFAR-10 with MindSpore

Image classification, a fundamental computer vision task, falls under supervised learning. Given an image, the goal is to predict its category. This article demonstrates how to use a ResNet50 network to classify the CIFAR-10 dataset using the MindSpore framework.

ResNet Architecture

ResNet, introduced by Kaiming He et al. in 2015, won the ILSVRC 2015 image classification competition; ResNet50 is its 50-layer variant. Plain deep convolutional networks suffer from degradation as depth increases: past a certain depth, deeper networks show higher training and test errors. ResNet introduces residual blocks to mitigate this issue. This makes it possible to train networks with more than 1,000 layers and improves accuracy at depths where plain networks degrade.

Dataset Preparation and Loading

CIFAR-10 contains 60,000 32×32 color images across 10 classes (6,000 per class), split into 50,000 training and 10,000 test images. The dataset is downloaded in binary format.

from download import download

# Fetch the CIFAR-10 binary archive and unpack it under ./cifar10
# (replace=True overwrites any copy that is already there).
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
download(url=url, path="./cifar10", kind="tar.gz", replace=True)

Directory structure:

cifar10/cifar-10-batches-bin
├── batches.meta.txt
├── data_batch_1.bin
├── data_batch_2.bin
├── data_batch_3.bin
├── data_batch_4.bin
├── data_batch_5.bin
├── readme.html
└── test_batch.bin

Define a function to load and augment the dataset:

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstype

# Data-pipeline and model configuration shared by the rest of the tutorial.
data_dir: str = "./cifar10/cifar-10-batches-bin"  # extracted CIFAR-10 binary batches
batch_size: int = 256
image_size: int = 32   # target size passed to vision.Resize (CIFAR-10 images are 32x32)
num_workers: int = 4   # parallel workers for reading and mapping the dataset
num_classes: int = 10  # number of CIFAR-10 categories

def create_cifar_dataset(dataset_dir, usage, resize, batch_size, workers):
    """Build a shuffled, batched CIFAR-10 pipeline from the binary files.

    Args:
        dataset_dir: directory containing the CIFAR-10 ``.bin`` batch files.
        usage: split passed to ``Cifar10Dataset`` (``"train"`` or ``"test"`` in
            this tutorial). ``"train"`` additionally enables augmentation:
            4-pixel padding + random 32x32 crop, and random horizontal flip.
        resize: target size given to ``vision.Resize``.
        batch_size: samples per batch.
        workers: parallel workers for reading and for both ``map`` stages.

    Returns:
        The batched dataset. Images are rescaled to [0, 1], normalised with
        fixed per-channel mean/std and converted HWC -> CHW; labels are int32.
    """
    cifar = ds.Cifar10Dataset(dataset_dir=dataset_dir, usage=usage,
                              num_parallel_workers=workers, shuffle=True)

    augment = (
        [vision.RandomCrop((32, 32), (4, 4, 4, 4)), vision.RandomHorizontalFlip(prob=0.5)]
        if usage == "train"
        else []
    )
    preprocess = [
        vision.Resize(resize),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW(),
    ]
    label_cast = transforms.TypeCast(mstype.int32)

    cifar = cifar.map(operations=augment + preprocess, input_columns='image', num_parallel_workers=workers)
    cifar = cifar.map(operations=label_cast, input_columns='label', num_parallel_workers=workers)
    return cifar.batch(batch_size)

# Both splits share every setting except ``usage``.
_loader_kwargs = dict(dataset_dir=data_dir, resize=image_size, batch_size=batch_size, workers=num_workers)
train_dataset = create_cifar_dataset(usage="train", **_loader_kwargs)
val_dataset = create_cifar_dataset(usage="test", **_loader_kwargs)
# Number of batches per epoch for each split.
train_steps, val_steps = train_dataset.get_dataset_size(), val_dataset.get_dataset_size()

The code for visualizing a few training samples is omitted for brevity.

Building the Network

Residual Blocks

Two types of residual blocks are used: BasicBlock for shallow networks (e.g., ResNet18) and Bottleneck for deeper networks (e.g., ResNet50). We implement Bottleneck for ResNet50.

from typing import Type, List, Optional
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Shared initializers: convolution weights ~ N(0, 0.02^2); ``gamma_init``
# (mean 1, sigma 0.02) is used for the BatchNorm scale in the projection shortcut.
weight_init = Normal(sigma=0.02, mean=0)
gamma_init = Normal(sigma=0.02, mean=1)

class BottleneckBlock(nn.Cell):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    The block outputs ``out_channels * expansion`` channels, and ``stride`` is
    applied by the 3x3 convolution. When the shortcut must change shape to match
    the main path, the caller passes a ``down_sample`` cell that projects ``x``.
    """

    expansion = 4

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, down_sample: Optional[nn.Cell] = None):
        super().__init__()
        widened = out_channels * self.expansion
        # MindSpore derives parameter names from these attribute names, and
        # checkpoints are matched by name, so the names are kept stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init=weight_init)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, weight_init=weight_init)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, widened, kernel_size=1, weight_init=weight_init)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):
        shortcut = x
        if self.down_sample is not None:
            shortcut = self.down_sample(x)

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        # The final ReLU comes after the residual addition.
        return self.relu(h + shortcut)

Helper function to create a stage of residual blocks:

def make_stage(last_out_channels: int, block: Type[BottleneckBlock], channels: int, num_blocks: int, stride: int = 1):
    """Build one ResNet stage as a ``SequentialCell`` of ``num_blocks`` blocks.

    Args:
        last_out_channels: channel count produced by the previous stage.
        block: residual block class exposing an ``expansion`` class attribute
            (``BottleneckBlock`` here). The annotation is ``Type[BottleneckBlock]``
            rather than ``Type[Union[...]]``, since ``Union`` is not imported.
        channels: inner width of each block; the stage outputs
            ``channels * block.expansion`` channels.
        num_blocks: number of blocks in the stage.
        stride: stride of the first block (the remaining blocks use stride 1).

    Returns:
        ``nn.SequentialCell`` holding the stage's blocks.
    """
    out_channels = channels * block.expansion
    down_sample = None
    # A 1x1 conv + BN projection shortcut is needed whenever the first block
    # changes the spatial size (stride != 1) or the channel count.
    if stride != 1 or last_out_channels != out_channels:
        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channels, out_channels, kernel_size=1, stride=stride, weight_init=weight_init),
            nn.BatchNorm2d(out_channels, gamma_init=gamma_init),
        ])

    # Only the first block strides and projects; the rest preserve shape.
    layers = [block(last_out_channels, channels, stride=stride, down_sample=down_sample)]
    layers.extend(block(out_channels, channels) for _ in range(1, num_blocks))
    return nn.SequentialCell(layers)

ResNet50 Model

ResNet50 consists of a stem convolution, four stages of Bottleneck blocks, an average pooling layer, and a fully connected layer.

class ResNetModel(nn.Cell):
    """ResNet classifier: stem, four residual stages, global pooling, dense head.

    Args:
        block: residual block class with an ``expansion`` class attribute.
        layers: number of blocks in each of the four stages, e.g. ``[3, 4, 6, 3]``.
        num_classes: output size of the final ``nn.Dense`` layer.
        input_channels: features entering ``fc``; must equal
            ``512 * block.expansion`` (2048 for ``BottleneckBlock``).
    """

    def __init__(self, block: Type[BottleneckBlock], layers: List[int], num_classes: int, input_channels: int):
        super().__init__()
        self.relu = nn.ReLU()
        # Stem: 7x7/2 convolution + 3x3/2 max-pool (4x spatial downsampling).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = make_stage(64, block, 64, layers[0])
        self.layer2 = make_stage(64 * block.expansion, block, 128, layers[1], stride=2)
        self.layer3 = make_stage(128 * block.expansion, block, 256, layers[2], stride=2)
        self.layer4 = make_stage(256 * block.expansion, block, 512, layers[3], stride=2)
        # Global average pooling: collapse whatever spatial size layer4 produces
        # to 1x1, so ``fc`` always receives ``input_channels`` features. A fixed
        # kernel_size=1 pool is only correct when layer4 is already 1x1 (32x32
        # inputs); with 224x224 inputs the 7x7 maps would reach ``fc`` unpooled.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(in_channels=input_channels, out_channels=num_classes)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

def resnet50(num_classes: int = 1000, pretrained: bool = False):
    """Build ResNet-50: Bottleneck blocks, stage depths 3/4/6/3, 2048-d features.

    Args:
        num_classes: output size of the classifier head.
        pretrained: when True, download the checkpoint from the hard-coded URL
            to ``./pretrained/resnet50.ckpt`` (overwriting any existing file)
            and load it into the model.

    Returns:
        A ``ResNetModel`` instance.
    """
    net = ResNetModel(BottleneckBlock, [3, 4, 6, 3], num_classes, 2048)
    if not pretrained:
        return net

    ckpt_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    local_ckpt = "./pretrained/resnet50.ckpt"
    download(ckpt_url, local_ckpt, replace=True)
    # NOTE(review): load_param_into_net matches parameters by name, so attribute
    # names (bn1, maxpool, ...) must match those stored in the checkpoint, or the
    # corresponding weights stay at their initial values. The checkpoint is
    # presumably a 1000-class model; confirm both before relying on it.
    ms.load_param_into_net(net, ms.load_checkpoint(local_ckpt))
    return net

Training and Evaluation

We fine-tune a pretrained ResNet50 (originally trained for 1000 classes) on CIFAR-10 (10 classes). To do so, we replace the final fully connected layer with a new one that outputs 10 logits.

# Load the 1000-class pretrained ResNet-50 and swap its classifier head for a
# freshly initialised one with ``num_classes`` outputs. Using the shared
# ``num_classes`` constant, rather than a literal 10, keeps the head in sync
# with the rest of the pipeline.
network = resnet50(pretrained=True)
in_features = network.fc.in_channels
network.fc = nn.Dense(in_features, num_classes)

Set up learning rate, optimizer, and loss:

num_epochs = 5
# Per-step learning-rate schedule: cosine decay from 1e-3 towards 1e-5,
# spread over every training step of every epoch.
lr = nn.cosine_decay_lr(
    min_lr=1e-5,
    max_lr=1e-3,
    total_step=train_steps * num_epochs,
    step_per_epoch=train_steps,
    decay_epoch=num_epochs,
)
optimizer = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
# Labels are class indices (sparse=True), and the loss is averaged over the batch.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

def forward_fn(inputs, targets):
    """Return ``loss_fn`` applied to ``network``'s logits for one batch."""
    return loss_fn(network(inputs), targets)

# Returns (loss, grads). grad_position=None: no gradients w.r.t. the inputs,
# only w.r.t. the parameters handed to the optimizer.
grad_fn = ms.value_and_grad(forward_fn, grad_position=None, weights=optimizer.parameters)

def train_step(inputs, targets):
    """Compute loss and gradients for one batch, apply the update, return the loss."""
    step_loss, step_grads = grad_fn(inputs, targets)
    optimizer(step_grads)
    return step_loss

Training loop with evaluation:

import os

best_acc = 0
best_ckpt_path = "./best_checkpoint/resnet50-best.ckpt"
# Derive the directory from the checkpoint path so the two cannot drift apart.
os.makedirs(os.path.dirname(best_ckpt_path), exist_ok=True)

def train_epoch(dataset):
    """Run one training epoch and return the mean per-step loss.

    Progress messages read the module-level ``epoch``, ``num_epochs`` and
    ``train_steps``, so this is meant to be called from the training loop.

    Args:
        dataset: iterable of ``(images, labels)`` batches, e.g. from
            ``create_tuple_iterator``.

    Returns:
        Mean loss over all steps as a Python float, or ``nan`` when the
        dataset yields no batches (instead of raising ZeroDivisionError).
    """
    network.set_train(True)
    losses = []
    for i, (images, labels) in enumerate(dataset):
        loss = train_step(images, labels)
        if i % 100 == 0 or i == train_steps - 1:
            # float() turns the scalar loss Tensor into a Python number, so the
            # ``:.3f`` spec works whether or not Tensor implements __format__.
            print(f'Epoch: [{epoch+1}/{num_epochs}], Step: [{i+1}/{train_steps}], Loss: {float(loss):.3f}')
        losses.append(loss)
    if not losses:
        return float('nan')
    return float(sum(losses) / len(losses))

def evaluate(dataset):
    """Return the top-1 accuracy of ``network`` on ``dataset``.

    Switches the network to inference mode with ``set_train(False)``.

    Args:
        dataset: iterable of ``(images, labels)`` batches.

    Returns:
        Fraction of correctly classified samples as a Python float; 0.0 if the
        dataset yields no samples (instead of raising ZeroDivisionError).
    """
    network.set_train(False)
    correct = 0
    total = 0
    for images, labels in dataset:
        preds = network(images).argmax(axis=1)
        correct += int((preds == labels).sum().asnumpy())
        total += labels.shape[0]
    return correct / total if total else 0.0

print("Start training...")
for epoch in range(num_epochs):
    # ``epoch`` is also read as a module-level global by train_epoch's progress output.
    train_loss = train_epoch(train_dataset.create_tuple_iterator(num_epochs=1))
    val_acc = evaluate(val_dataset.create_tuple_iterator(num_epochs=1))
    # float() accepts Python numbers, NumPy scalars and 0-d Tensors alike, so
    # the format spec is valid whatever train_epoch/evaluate return.
    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Train Loss: {float(train_loss):.3f}, Accuracy: {float(val_acc):.3f}")
    # Keep only the checkpoint with the best validation accuracy so far.
    if val_acc > best_acc:
        best_acc = val_acc
        ms.save_checkpoint(network, best_ckpt_path)
print(f"Training complete. Best accuracy: {best_acc:.3f}. Model saved to {best_ckpt_path}")

Visualizing Predictions

Load the best model and display predictions on test images (blue for correct, red for incorrect).

import matplotlib.pyplot as plt
import numpy as np

def visualize_predictions(ckpt_path, dataset):
    """Plot a checkpointed model's predictions on up to 6 images of one batch.

    Titles are blue for correct predictions and red for wrong ones.

    Args:
        ckpt_path: checkpoint saved from the fine-tuned ``num_classes``-way network.
        dataset: batched dataset with ``image`` (normalised CHW) and ``label`` columns.
    """
    import os

    net = resnet50(num_classes=num_classes)
    ms.load_param_into_net(net, ms.load_checkpoint(ckpt_path))
    net.set_train(False)

    batch = next(dataset.create_dict_iterator())
    images = batch['image'].asnumpy()
    labels = batch['label'].asnumpy()
    preds = np.argmax(net(batch['image']).asnumpy(), axis=1)

    # One class name per non-empty line of the CIFAR-10 metadata file.
    with open(os.path.join(data_dir, "batches.meta.txt"), "r", encoding="utf-8") as f:
        classes = [line.strip() for line in f if line.strip()]

    # Inverse of the Normalize(mean, std) step applied in create_cifar_dataset.
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2023, 0.1994, 0.2010])

    plt.figure(figsize=(12, 8))
    # The 2x3 grid holds 6 images; a smaller final batch must not raise IndexError.
    for i in range(min(6, len(preds))):
        plt.subplot(2, 3, i + 1)
        color = 'blue' if preds[i] == labels[i] else 'red'
        plt.title(f"Predict: {classes[preds[i]]}", color=color)
        img = std * np.transpose(images[i], (1, 2, 0)) + mean
        plt.imshow(np.clip(img, 0, 1))
        plt.axis('off')
    plt.show()

# Show predictions of the best checkpoint on a batch from the test split.
visualize_predictions(ckpt_path=best_ckpt_path, dataset=val_dataset)

After 5 epochs, the model reaches around 73% accuracy. For better results, train for 80 epochs or more.

Tags: resnet50 image classification mindspore CIFAR-10

Posted on Tue, 26 May 2026 17:23:47 +0000 by tempi