HCS²-Net: Unsupervised Spatial-Spectral Network for Hyperspectral Compressive Snapshot Reconstruction

Table of Contents

  1. Article Overview
  2. Framework Workflow
  3. Code Analysis
1. Article Overview

  1. Problem Context: Hyperspectral compressive imaging applies compressed sensing theory to capture hyperspectral data in a single snapshot through a coded aperture, avoiding temporal scanning. The core challenge is reconstructing the original hyperspectral image from this compressed snapshot (a toy forward model is sketched after this overview).

  2. Limitations of Existing Methods: Due to variations in spectral response characteristics and wavelength ranges among different imaging devices, current approaches often fail to capture complex spectral variations and lack adaptability to new imaging systems.

  3. Advantages of HCS²-Net: To address these issues, this paper introduces an unsupervised spatial-spectral network that reconstructs hyperspectral images solely from compressed snapshot measurements. This network functions as a conditional generative model, using the snapshot measurement as a condition. It employs a spatial-spectral attention module to learn spatial-spectral correlations within the hyperspectral image. Network parameters are optimized to ensure that outputs closely match the given measurements according to the imaging model, thereby improving adaptability across various imaging setups.

  4. Network Architecture: The architecture consists of:

    • Input Layer: Random codes Z and snapshot measurements Y.
    • Feature Extraction: Multiple 1×1 bottleneck residual blocks (BRBs).
    • Attention Module: Spatial-spectral attention mechanism.
    • Output Layer: Reconstructed hyperspectral image.
  5. Training Process: Parameters are optimized by minimizing the error between the reconstruction, re-projected through the imaging model, and the measured snapshot.

  6. Experimental Results: Experiments on multiple datasets show that HCS²-Net outperforms existing methods in reconstruction quality.
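The coded-aperture measurement in item 1 can be illustrated with a toy forward operator. This is a minimal sketch assuming a mask-and-sum sensing model (each band modulated by a per-band binary pattern, then summed into one 2D snapshot); the paper's exact operator, e.g. any dispersive shift in CASSI-style systems, may differ.

import torch

# Toy coded-aperture snapshot (assumed mask-and-sum model):
# Y = sum over bands of (mask_b * X_b), compressing a 3D cube into one 2D image.
def snapshot_measurement(x, mask):
    # x:    hyperspectral cube, shape (bands, H, W)
    # mask: coded aperture patterns, shape (bands, H, W)
    return (x * mask).sum(dim=0)  # snapshot Y, shape (H, W)

cube = torch.rand(31, 64, 64)                   # toy cube with 31 bands
mask = (torch.rand(31, 64, 64) > 0.5).float()   # random binary coded aperture
y = snapshot_measurement(cube, mask)
print(y.shape)  # torch.Size([64, 64])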

2. Framework Workflow

[Figure: HCS²-Net workflow diagram. BN: Batch Normalization, ReLU: Rectified Linear Unit]

1. Input:

  • Random Codes Z: A randomly generated matrix used to encode the hyperspectral image.
  • Snapshot Measurement Y: Two-dimensional compressed measurement obtained from the original hyperspectral image using a coded aperture.
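
As a concrete illustration, the two inputs can be stacked along the channel dimension into one tensor. This minimal sketch assumes the 64+1 channel layout used in the code analysis below; the spatial size is arbitrary.

import torch

H, W = 128, 128
z = torch.randn(1, 64, H, W)          # random codes Z (64 channels, assumed)
y = torch.rand(1, 1, H, W)            # snapshot measurement Y (single channel)
net_input = torch.cat([z, y], dim=1)  # (1, 65, H, W) network input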

2. Feature Extraction: Bottleneck Residual Blocks (BRBs)

  • 1×1 Bottleneck Residual Blocks: The input is processed through several BRBs for feature extraction. Each block combines 1×1 convolutions with skip connections over stacked 3×3 convolutions to keep feature extraction efficient (a minimal sketch follows; the article's excerpt appears in Section 2.2 of the code analysis).
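
Below is a minimal, self-contained sketch of one such block, consistent with the implementation excerpted in Section 2.2 of the code analysis; channel sizes are illustrative.

import torch
import torch.nn as nn

class BRB(nn.Module):
    # Bottleneck residual block: stacked 3x3 convs on the main path,
    # a 1x1 conv on the skip path, as described above.
    def __init__(self, channels):
        super().__init__()
        self.conv3_1 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.conv3_2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.conv1 = nn.Conv2d(channels, channels, 1, 1)
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU()

    def forward(self, x):
        identity = x
        out = self.act(self.bn(self.conv3_1(x)))
        out = self.conv3_2(out)
        out = out + self.conv1(identity)  # 1x1 conv on the skip connection
        return self.act(self.bn(out))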

3. Spatial-Spectral Attention Module:

  • Attention Mechanism: This module uses attention to capture spatial-spectral correlations in hyperspectral images. It learns a 3D attention map to weight each feature point, emphasizing important features and suppressing irrelevant ones.
  • Multi-scale Fusion: The module also incorporates multi-scale fusion to combine features at different scales, further enhancing reconstruction accuracy.
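
Conceptually, the weighting amounts to multiplying the features by a learned map squashed to (0, 1). The toy sketch below shows only that core idea, not the article's full multi-scale module (which is listed in the code analysis):

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    # A 1x1 conv predicts a per-position, per-channel weight map;
    # a sigmoid squashes it to (0, 1) and it is applied multiplicatively.
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        att = torch.sigmoid(self.conv(x))  # 3D attention map, same shape as x
        return x * att                     # emphasize/suppress features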

4. Output:

  • A 1×1 convolution adjusts the number of channels to match the number of bands in the hyperspectral image.
  • A Sigmoid activation function constrains the output values between 0 and 1.
  • The final output is the reconstructed hyperspectral image after processing through feature extraction and attention mechanisms.
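
A minimal sketch of this output head (the band count of 31 is illustrative):

import torch
import torch.nn as nn

bands = 31                                # assumed number of spectral bands
head = nn.Sequential(
    nn.Conv2d(32, bands, kernel_size=1),  # 1x1 conv maps features to band count
    nn.Sigmoid(),                         # constrain outputs to [0, 1]
)
features = torch.randn(1, 32, 64, 64)
recon = head(features)                    # (1, 31, 64, 64)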

5. Network Training:

  • Loss Function: Optimization minimizes the difference between the snapshot predicted from the reconstruction via the imaging model and the measured snapshot (see the sketch after this list).
  • Optimizer: Adam optimizer is used to update network parameters.
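
The sketch below illustrates this unsupervised loop end to end, again assuming the toy mask-and-sum operator from the overview and a stand-in convolution in place of the real network; only the measured snapshot is needed as supervision.

import torch
import torch.nn as nn
import torch.nn.functional as F

bands, H, W = 31, 64, 64
model = nn.Conv2d(65, bands, 3, padding=1)      # stand-in for HCS²-Net
net_input = torch.randn(1, 65, H, W)            # random codes Z + measurement Y
mask = (torch.rand(bands, H, W) > 0.5).float()  # coded aperture
y = torch.rand(H, W)                            # measured snapshot

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(100):
    optimizer.zero_grad()
    x_hat = torch.sigmoid(model(net_input))     # reconstructed cube (1, bands, H, W)
    y_hat = (x_hat[0] * mask).sum(dim=0)        # re-apply the imaging model
    loss = F.mse_loss(y_hat, y)                 # match the measured snapshot
    loss.backward()
    optimizer.step()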

Summary:

HCS²-Net takes random codes and a snapshot measurement as input, extracts features with BRBs and a spatial-spectral attention module, and optimizes its parameters by minimizing the measurement error under the imaging model to produce the final hyperspectral reconstruction.

3. Code Analysis

1. Spatial-Spectral Attention Module

import torch
import torch.nn as nn
import torch.nn.functional as F

# Spatial-spectral attention module
class spatial_att_new(nn.Module):

    def __init__(self, input_channel):
        super(spatial_att_new, self).__init__()
        self.bn = nn.BatchNorm2d(input_channel)
        self.relu = nn.ReLU(inplace=True)

        # Spatial attention weight generation
        self.conv_sa = nn.Conv2d(input_channel, input_channel, kernel_size=1, stride=1, bias=True)

        # Channel adjustment layer
        self.conv_1 = nn.Conv2d(input_channel, input_channel, kernel_size=1, stride=1, bias=True)

        # Local feature extraction
        self.conv_3 = nn.Conv2d(input_channel, input_channel, kernel_size=3, stride=1, padding=1, bias=True)

        # Downsampling
        self.con_stride = nn.Conv2d(input_channel, input_channel, 3, stride=2, padding=1)

        # Multi-scale feature fusion
        self.con = nn.Conv2d(input_channel * 2, input_channel, 3, 1, 1)

    def forward(self, x):
        # First downsampling stage (stride-2 conv) plus local feature extraction
        x0 = self.relu(self.bn(self.conv_3(self.con_stride(x))))
        x1 = self.relu(self.bn(self.conv_1(x0)))

        # Second downsampling stage
        x2 = self.relu(self.bn(self.conv_3(self.con_stride(x1))))
        x2 = self.relu(self.bn(self.conv_1(x2)))
        # Upsample back to the first-stage resolution
        x2 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False)

        x2 = self.conv_1(self.relu(self.bn(self.conv_1(self.conv_3(x2)))))
        x2_ = torch.sigmoid(x2)  # attention map for the first-stage features

        # Re-derive first-stage features and weight them with the attention map
        x1 = self.relu(self.bn(self.conv_3(x0)))
        x3 = x1 * x2_ * 2  # the x2 factor offsets the sigmoid's 0-1 range

        # Multi-scale fusion of attended features with the attention branch
        x3 = torch.cat([x3, x2], dim=1)
        x3 = self.relu(self.bn(self.con(x3)))

        x3 = self.relu(self.bn(self.conv_1(x3)))
        # Upsample back to the input resolution
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)

        x3 = self.conv_3(x3)
        x3 = self.conv_1(self.relu(self.bn(self.conv_1(x3))))
        x4 = torch.sigmoid(x3)  # full-resolution attention map

        # Crop to the input size if down/up-sampling changed the spatial dims
        if x4.shape != x.shape:
            b, c, h, w = x.size()
            x4 = x4[:, :, :h, :w]
            x3 = x3[:, :, :h, :w]

        out = x * x4 * 2  # weight the input features with the attention map
        out = torch.cat([x3, out], dim=1)
        out = self.relu(self.bn(self.con(out)))

        return out
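
A quick shape check for the module above (sizes are illustrative; even spatial dimensions keep the cropping branch inactive):

att = spatial_att_new(input_channel=32)
x = torch.randn(1, 32, 128, 128)
out = att(x)
print(out.shape)  # torch.Size([1, 32, 128, 128])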

2. Network Design

2.1 Input Processing

# Extract first 64 channels as noise input
noise_input = x[:, 0:64, :, :]
# Extract 65th channel as snapshot measurement
y_input = x[:, 64:65, :, :]

# Apply convolution to random codes
noise_input = self.conv_z(noise_input)

# Apply sigmoid to snapshot measurement
y = torch.sigmoid(self.conv_y(y_input))

# Concatenate inputs
x = torch.cat([noise_input, y], dim=1)
x = self.act(self.bn(self.conv_3(x)))

2.2 Bottleneck Residual Block (BRB)

# First bottleneck residual block (BRB)
identity1 = x
x = self.act(self.bn(self.conv3_1(x)))
x = self.conv3_2(x)
x1 = x + self.conv1(identity1)
x1 = self.act(self.bn(x1))

2.3 Complete Network Structure

class hslnet(nn.Module):

    def __init__(self, in_channels, out_channels, middle_channels):
        super(hslnet, self).__init__()
        self.conv_3 = nn.Conv2d(in_channels, middle_channels, 3, 1, 1)
        self.conv_1 = nn.Conv2d(middle_channels, out_channels, 1, 1)

        self.att = spatial_att_new(middle_channels)
        self.bn = nn.BatchNorm2d(middle_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

        self.conv3_1 = nn.Conv2d(middle_channels, middle_channels, 3, 1, padding=1, dilation=1)
        self.conv3_2 = nn.Conv2d(middle_channels, middle_channels, 3, 1, padding=1, dilation=1)
        self.conv3_3 = nn.Conv2d(middle_channels, middle_channels, 3, 1, padding=1, dilation=1)
        self.conv3_4 = nn.Conv2d(middle_channels, middle_channels, 3, 1, padding=1, dilation=1)
        self.conv1 = nn.Conv2d(middle_channels, middle_channels, 1, 1)

        self.conv_y = nn.Conv2d(1, 1, 3, 1, 1)
        self.conv_z = nn.Conv2d(64, 64, 3, 1, 1)

    def forward(self, x):
        # Split inputs
        noise_input = x[:, 0:64, :, :]
        y_input = x[:, 64:65, :, :]

        # Process inputs
        noise_input = self.conv_z(noise_input)
        y = torch.sigmoid(self.conv_y(y_input))
        x = torch.cat([noise_input, y], dim=1)

        # Initial convolution
        x = self.act(self.bn(self.conv_3(x)))

        # BRB blocks
        identity1 = x
        x = self.act(self.bn(self.conv3_1(x)))
        x = self.conv3_2(x)
        x1 = x + self.conv1(identity1)
        x1 = self.act(self.bn(x1))

        identity2 = x1
        x = self.act(self.bn(self.conv3_3(x1)))
        x = self.conv3_4(x)
        x2 = x + self.conv1(identity2)
        x2 = self.act(self.bn(x2))

        identity3 = x2
        x = self.act(self.bn(self.conv3_3(x2)))
        x = self.conv3_2(x)
        x3 = x + self.conv1(identity3)
        x3 = self.act(self.bn(x3))

        # Apply attention module
        x3 = self.att(x3)

        # Final output
        x = self.conv_1(x3)
        x = torch.sigmoid(x)

        return x

2.4 Testing

# Create model instance (out_channels would normally equal the number of
# spectral bands; 1 is used here just for a quick shape test)
model = hslnet(in_channels=65, out_channels=1, middle_channels=32)

# Define test input
input_tensor = torch.randn(1, 65, 512, 512)

# Run model
output = model(input_tensor)
print("Output shape:", output.shape)

Output result:

Output shape: torch.Size([1, 1, 512, 512])

Tags: hyperspectral imaging, compressed sensing, deep learning, attention mechanism, neural networks
