ResNet50 Implementation for CIFAR-10 Image Classification

Image Classification Fundamentals

Image classification is a foundational supervised-learning task in computer vision. Given an input image (e.g., a cat, vehicle, or aircraft), the objective is to assign the correct category label. This implementation applies the ResNet50 architecture to the CIFAR-10 dataset.

ResNet Architecture Overview

Introduced by He et al. in 2015, ResNet took first place in the ILSVRC 2015 classification competition. Plain deep CNNs suffer from performance degradation as depth increases: on CIFAR-10, a 56-layer plain network shows higher training and test error than a 20-layer one. ResNet addresses this with residual (shortcut) connections, which make much deeper networks (beyond 1,000 layers) trainable while improving accuracy.
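The core idea can be illustrated outside any framework: a residual block learns a correction F(x) that is added back to its input, so the block only has to model the deviation from identity. A minimal numpy sketch, where the matrix W is a hypothetical stand-in for the block's convolutions:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def residual_block(x, W):
    # F(x): the learned residual transformation (a linear map here,
    # standing in for conv -> BN -> ReLU -> conv -> BN)
    fx = relu(x @ W)
    # The shortcut adds the input back, so the identity path carries
    # the signal (and gradients) even when F contributes little
    return relu(fx + x)

x = np.ones(4)
W = np.zeros((4, 4))  # an all-zero residual leaves the input unchanged
out = residual_block(x, W)
```

With W at zero the block reduces to the identity, which is exactly why stacking many such blocks does not degrade the signal the way stacking plain layers can.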

Data Preparation and Processing

CIFAR-10 comprises 60,000 32×32 color images across 10 classes (6,000 per class). The following pipeline loads and preprocesses the dataset:

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

def load_cifar10_data(data_root, split, img_size, batch_size, workers):
    # Read the raw CIFAR-10 files; `split` is "train" or "test"
    dataset = ds.Cifar10Dataset(
        dataset_dir=data_root,
        usage=split,
        num_parallel_workers=workers,
        shuffle=True
    )
    
    # Augmentation (padded random crop, horizontal flip) for training only
    transform_ops = []
    if split == "train":
        transform_ops.extend([
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(0.5)
        ])
    
    # Shared preprocessing: resize, scale pixels to [0, 1], normalize with
    # CIFAR-10 per-channel statistics, and convert HWC to CHW layout
    transform_ops.extend([
        vision.Resize(img_size),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ])
    
    # Labels arrive as unsigned integers; cast to int32 for the loss function
    label_transform = ds.transforms.TypeCast(ms.int32)
    
    dataset = dataset.map(operations=transform_ops, input_columns="image", num_parallel_workers=workers)
    dataset = dataset.map(operations=label_transform, input_columns="label", num_parallel_workers=workers)
    return dataset.batch(batch_size)
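The Rescale and Normalize steps compose to (x / 255 - mean) / std per channel; a quick numpy check of that composition, using the same CIFAR-10 statistics as the pipeline above:

```python
import numpy as np

# CIFAR-10 per-channel mean and standard deviation (as in the pipeline)
mean = np.array([0.4914, 0.4822, 0.4465])
std = np.array([0.2023, 0.1994, 0.2010])

def preprocess_pixel(rgb):
    # Rescale(1/255, 0) followed by Normalize(mean, std)
    return (np.asarray(rgb) / 255.0 - mean) / std

white = preprocess_pixel([255, 255, 255])  # brightest input pixel
black = preprocess_pixel([0, 0, 0])        # darkest input pixel
```

Inputs end up roughly zero-centered with unit scale, which keeps the pretrained batch-norm statistics in a sensible operating range.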

Residual Block Implementation

ResNet employs two residual block types:

Basic Block (for shallow networks)

Comprises two 3×3 convolution layers, each followed by batch normalization, with ReLU activation. The shortcut is an identity mapping unless the dimensions change, in which case a downsample projection aligns them:

class BasicBlock(ms.nn.Cell):
    expansion = 1
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = ms.nn.Conv2d(in_channels, out_channels, 3, stride, pad_mode='same')
        self.bn1 = ms.nn.BatchNorm2d(out_channels)
        self.conv2 = ms.nn.Conv2d(out_channels, out_channels, 3, pad_mode='same')
        self.bn2 = ms.nn.BatchNorm2d(out_channels)
        self.relu = ms.nn.ReLU()
        self.downsample = downsample
    
    def construct(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)

Bottleneck Block (for deep networks)

Uses a 1×1 → 3×3 → 1×1 convolution sequence for computational efficiency: the first 1×1 convolution reduces the channel count before the expensive 3×3 convolution, and the last 1×1 expands it again. Dimension mismatches on the shortcut are handled by a downsample projection with the appropriate channel expansion and stride:

class BottleneckBlock(ms.nn.Cell):
    expansion = 4
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BottleneckBlock, self).__init__()
        self.conv1 = ms.nn.Conv2d(in_channels, out_channels, 1)
        self.bn1 = ms.nn.BatchNorm2d(out_channels)
        self.conv2 = ms.nn.Conv2d(out_channels, out_channels, 3, stride, pad_mode='same')
        self.bn2 = ms.nn.BatchNorm2d(out_channels)
        self.conv3 = ms.nn.Conv2d(out_channels, out_channels * self.expansion, 1)
        self.bn3 = ms.nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = ms.nn.ReLU()
        self.downsample = downsample
    
    def construct(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)
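The efficiency gain of the bottleneck design is easy to quantify. For 256-channel features, two 3×3 convolutions (a basic block) cost far more weights than the 1×1 → 3×3 → 1×1 bottleneck, which reduces to 64 channels first (ignoring biases and batch norm):

```python
def conv_params(c_in, c_out, k):
    # Weight count of a k x k convolution, ignoring bias terms
    return c_in * c_out * k * k

# Basic block at 256 channels: two 3x3 convolutions
basic = conv_params(256, 256, 3) + conv_params(256, 256, 3)

# Bottleneck: 1x1 reduce to 64, 3x3 at 64, 1x1 expand back to 256
bottleneck = (conv_params(256, 64, 1)
              + conv_params(64, 64, 3)
              + conv_params(64, 256, 1))

print(basic, bottleneck)  # the bottleneck uses roughly 17x fewer weights
```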

ResNet50 Network Construction

The ResNet50 architecture consists of:

  1. Initial convolution (7×7, stride=2)
  2. Max pooling (3×3, stride=2)
  3. Four residual blocks (conv2_x to conv5_x)
  4. Global average pooling and classification head

def build_resnet50(num_classes=10):
    model = ms.nn.SequentialCell([
        ms.nn.Conv2d(3, 64, 7, 2, pad_mode='same'),
        ms.nn.BatchNorm2d(64),
        ms.nn.ReLU(),
        ms.nn.MaxPool2d(3, 2, pad_mode='same'),
        # Each stage takes the previous stage's output channels as input
        build_block(BottleneckBlock, 64, 64, 3, 1),     # conv2_x -> 256 channels
        build_block(BottleneckBlock, 256, 128, 4, 2),   # conv3_x -> 512 channels
        build_block(BottleneckBlock, 512, 256, 6, 2),   # conv4_x -> 1024 channels
        build_block(BottleneckBlock, 1024, 512, 3, 2),  # conv5_x -> 2048 channels
        ms.nn.AdaptiveAvgPool2d(1),
        ms.nn.Flatten(),
        ms.nn.Dense(2048, num_classes)
    ])
    return model

def build_block(block, in_channels, channels, num_blocks, stride):
    # A projection shortcut is needed whenever the spatial size or the
    # channel count changes across the first block of the stage
    downsample = None
    if stride != 1 or in_channels != channels * block.expansion:
        downsample = ms.nn.SequentialCell([
            ms.nn.Conv2d(in_channels, channels * block.expansion, 1, stride),
            ms.nn.BatchNorm2d(channels * block.expansion)
        ])
    layers = [block(in_channels, channels, stride, downsample)]
    # Remaining blocks take the expanded channel count and use stride 1
    for _ in range(1, num_blocks):
        layers.append(block(channels * block.expansion, channels))
    return ms.nn.SequentialCell(layers)
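Assuming the standard 224×224 input that the pretrained checkpoint expects, the spatial resolution halves at the stem convolution, the max pool, and the first block of conv3_x through conv5_x. A quick arithmetic check of the stage output sizes:

```python
def downsample(size, stride):
    # With 'same' padding, output size = ceil(input / stride)
    return -(-size // stride)

size = 224
size = downsample(size, 2)   # 7x7 stem conv, stride 2 -> 112
size = downsample(size, 2)   # 3x3 max pool, stride 2 -> 56
stage_sizes = []
for stride in (1, 2, 2, 2):  # conv2_x .. conv5_x first-block strides
    size = downsample(size, stride)
    stage_sizes.append(size)
print(stage_sizes)           # 56, 28, 14, 7 before global average pooling
```

The adaptive average pool then collapses the final 7×7 map to 1×1, so the Dense head always sees a 2048-dimensional vector regardless of input resolution.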

Training and Evaluation Pipeline

Pre-trained ResNet50 weights (from ImageNet, 1,000 classes) are fine-tuned for CIFAR-10 (10 classes). Since build_resnet50 already creates a 10-class Dense head, the mismatched classifier weights are filtered out of the checkpoint before loading, leaving the new head randomly initialized:

model = build_resnet50(num_classes=10)
pretrained_weights = ms.load_checkpoint("resnet50_224_new.ckpt")

# Drop the 1,000-class classifier parameters; only backbone weights load
# (the classifier's parameter prefix may differ between checkpoints)
backbone_weights = {k: v for k, v in pretrained_weights.items()
                    if not k.startswith("fc")}
ms.load_param_into_net(model, backbone_weights)

# Training configuration
optimizer = ms.nn.Momentum(model.trainable_params(), 0.001, 0.9)
loss_fn = ms.nn.SoftmaxCrossEntropyWithLogits(sparse=True)

def forward_fn(data, label):
    logits = model(data)
    return loss_fn(logits, label)

# Differentiate the forward function with respect to the trainable parameters
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(data, label):
    loss, grads = grad_fn(data, label)
    optimizer(grads)
    return loss

# Training loop (5 epochs example)
for epoch in range(5):
    for batch_idx, (images, labels) in enumerate(train_dataset.create_tuple_iterator()):
        loss = train_step(images, labels)
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.asnumpy():.4f}")
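Evaluation compares the argmax of the logits against the labels over the test split. A minimal numpy sketch of the accuracy metric, using a hypothetical 4-sample batch in place of real model outputs:

```python
import numpy as np

def batch_accuracy(logits, labels):
    # Predicted class = index of the largest logit per row
    preds = np.argmax(logits, axis=1)
    return float(np.mean(preds == labels))

# Hypothetical batch of 4 samples with 10-class logits; in practice these
# would come from model(images) over the test dataset
logits = np.zeros((4, 10))
logits[0, 3] = logits[1, 3] = logits[2, 7] = logits[3, 1] = 1.0
labels = np.array([3, 3, 7, 0])
acc = batch_accuracy(logits, labels)  # 3 of 4 predictions correct
```

Averaging this metric over all test batches gives the overall CIFAR-10 accuracy of the fine-tuned model.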

Tags: mindspore resnet50 cifar10 image-classification residual-block

Posted on Sat, 16 May 2026 14:01:15 +0000 by soulrazer