Understanding Denoising Diffusion Probabilistic Models
Forward Diffusion Process
The forward diffusion process in Denoising Diffusion Probabilistic Models (DDPMs) gradually transforms clean data into noise over a fixed number of steps $T$. It is a Markov chain in which each step scales the current sample down slightly and adds a small amount of Gaussian noise, with the amount set by a variance schedule $\beta_1, \dots, \beta_T$:
$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1-\beta_t}\,x_{t-1}, \beta_t I\big).$$
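To see why this chain can be collapsed into a single closed-form step, the NumPy sketch below applies the per-step transition $x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon_t$ and, separately, tracks how the chain scales $x_0$ and how much noise variance it accumulates. The linear schedule from 1e-4 to 0.02 and the helper names `forward_step` and `composed_moments` are assumptions made for illustration only.

```python
import numpy as np


def forward_step(x_prev, beta_t, rng):
    """One Markov step of the forward process: x_t ~ N(sqrt(1 - beta_t) * x_prev, beta_t * I)."""
    noise = rng.standard_normal(x_prev.shape)
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * noise


def composed_moments(betas):
    """Return (scale applied to x0, accumulated noise variance) after running the chain over `betas`."""
    signal_scale, noise_var = 1.0, 0.0
    for beta in betas:
        # Each step shrinks the existing signal and noise, then adds fresh noise of variance beta
        signal_scale *= np.sqrt(1.0 - beta)
        noise_var = (1.0 - beta) * noise_var + beta
    return signal_scale, noise_var


# Illustrative linear schedule (an assumption, chosen only for this demo)
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)

# Simulate a few steps of the chain on a toy 4-element "image"
rng = np.random.default_rng(0)
x = np.ones(4)
for beta in betas[:10]:
    x = forward_step(x, beta, rng)

# After 500 steps the chain's moments match sqrt(alpha_bar) and 1 - alpha_bar
scale, var = composed_moments(betas[:500])
print(scale, np.sqrt(alpha_bar[499]))
print(var, 1.0 - alpha_bar[499])
```

Both quantities depend only on the cumulative product of $1-\beta_s$, which is what makes the one-shot formula possible.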
Noise Addition Mechanism
Because every step is Gaussian, the steps compose into a single Gaussian, so a noisy sample at any timestep $t$ can be drawn directly from the original image $x_0$ without simulating the intermediate steps:
$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$$
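where $\epsilon \sim \mathcal{N}(0, I)$ is noise sampled from a standard normal distribution, $\alpha_t = 1 - \beta_t$, and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$ is the cumulative product of the alpha values up to timestep $t$. As $t$ grows, $\bar\alpha_t$ shrinks toward 0, so the signal term fades and the noise term dominates.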
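As a quick numeric check, the sketch below rebuilds in NumPy the default schedule used by the `DiffusionSampler` class in this section (`beta_start=0.00085`, `beta_end=0.0120`, 1000 steps, betas spaced linearly in square-root space) and prints the two coefficients at a few timesteps. The NumPy port and the helper name `noise_coefficients` are assumptions for illustration. The squares of the two coefficients always sum to 1, so $x_t$ keeps unit variance whenever $x_0$ has unit variance.

```python
import numpy as np

T = 1000
# Betas spaced linearly in sqrt space, then squared (mirrors the sampler's defaults)
betas = np.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, T) ** 2
alpha_bar = np.cumprod(1.0 - betas)


def noise_coefficients(t):
    """Return (sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t)) for a 0-indexed timestep t."""
    return np.sqrt(alpha_bar[t]), np.sqrt(1.0 - alpha_bar[t])


for t in (0, 250, 500, 999):
    signal, noise = noise_coefficients(t)
    print(f"t={t:4d}  signal={signal:.4f}  noise={noise:.4f}  "
          f"sum of squares={signal ** 2 + noise ** 2:.4f}")
```

With this schedule the signal coefficient at the last training step is small (below 0.1) but not exactly zero.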
Implementation of Noise Addition
The following code sets up the noise schedule and implements the noise addition process:
```python
import numpy as np
import torch


class DiffusionSampler:
    def __init__(self, random_generator, training_steps=1000, beta_start=0.00085, beta_end=0.0120):
        """
        Initialize the diffusion sampler with parameters for the noise schedule.

        Args:
            random_generator: torch.Generator used for noise sampling (on the same device as the data)
            training_steps: Number of diffusion steps used during training
            beta_start: Starting value of the beta schedule
            beta_end: Ending value of the beta schedule
        """
        # Space beta values linearly in sqrt space, then square them
        self.beta_values = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, training_steps, dtype=torch.float32) ** 2
        # Calculate alpha values as 1 - beta
        self.alpha_values = 1.0 - self.beta_values
        # Compute cumulative product of alphas (alpha-bar)
        self.alpha_cumprod = torch.cumprod(self.alpha_values, dim=0)
        # Store other parameters
        self.unit = torch.tensor(1.0)
        self.random_generator = random_generator
        self.total_training_steps = training_steps
        # Until set_sampling_steps() is called, sample with every training step
        self.num_sampling_steps = training_steps
        # Create timesteps in reverse order for sampling
        self.timesteps = torch.from_numpy(np.arange(0, training_steps)[::-1].copy())

    def set_sampling_steps(self, num_steps=50):
        """
        Set the number of steps to use during the sampling/inference process.

        Args:
            num_steps: Number of steps to use during inference
        """
        self.num_sampling_steps = num_steps
        # Calculate the ratio between training and sampling steps
        step_ratio = self.total_training_steps // self.num_sampling_steps
        # Evenly spaced timesteps in descending order, e.g. [980, 960, ..., 0] for 1000/50
        sampling_timesteps = (np.arange(0, num_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(sampling_timesteps)

    def add_noise(self, original_data, time_steps):
        """
        Add noise in closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.

        Args:
            original_data: The original data (e.g., images or latents), shape (batch, ...)
            time_steps: Tensor of timesteps, one per batch element (or a single timestep)
        Returns:
            The noisy data
        """
        # Move alpha_cumprod to the same device and dtype as original_data
        alpha_cumprod = self.alpha_cumprod.to(device=original_data.device, dtype=original_data.dtype)
        # Move time_steps to the same device as original_data
        time_steps = time_steps.to(original_data.device)
        # Calculate sqrt of cumulative alphas for the given timesteps
        sqrt_alpha_prod = alpha_cumprod[time_steps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append singleton dimensions so the coefficient broadcasts over original_data
        while len(sqrt_alpha_prod.shape) < len(original_data.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        # Calculate sqrt of (1 - cumulative alphas) for the given timesteps
        sqrt_one_minus_alpha_prod = (1 - alpha_cumprod[time_steps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        # Append singleton dimensions so the coefficient broadcasts over original_data
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_data.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        # Sample random noise
        random_noise = torch.randn(original_data.shape, generator=self.random_generator,
                                   device=original_data.device, dtype=original_data.dtype)
        # Apply the noise addition formula
        noisy_data = sqrt_alpha_prod * original_data + sqrt_one_minus_alpha_prod * random_noise
        return noisy_data
```
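The `while ... unsqueeze(-1)` loops in `add_noise` exist so that one coefficient per batch element broadcasts over the remaining dimensions. The NumPy sketch below is a hypothetical port (`add_noise_np`, not part of the class) that performs the same reshape in one line and takes the noise as an argument, so each batch element can be checked against the formula with its own timestep.

```python
import numpy as np

betas = np.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000) ** 2
alpha_bar = np.cumprod(1.0 - betas)


def add_noise_np(x0, timesteps, noise):
    """Hypothetical NumPy port of DiffusionSampler.add_noise with caller-supplied noise.

    x0: array of shape (batch, ...); timesteps: int array of shape (batch,).
    """
    # Reshape (batch,) -> (batch, 1, 1, ...) so coefficients broadcast like the unsqueeze loop
    shape = (-1,) + (1,) * (x0.ndim - 1)
    sqrt_ab = np.sqrt(alpha_bar[timesteps]).reshape(shape)
    sqrt_one_minus_ab = np.sqrt(1.0 - alpha_bar[timesteps]).reshape(shape)
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise


rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 4, 8, 8))   # e.g. a batch of two 4x8x8 latents
noise = rng.standard_normal(x0.shape)
timesteps = np.array([10, 900])          # a different timestep per batch element
x_t = add_noise_np(x0, timesteps, noise)
print(x_t.shape)
```

Passing the noise in explicitly makes the result easy to verify; the class instead draws it from its `torch.Generator`, which keeps runs reproducible for a fixed seed.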
Distribution Parameters Calculation
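To run the process in reverse, from $x_t$ back to $x_{t-1}$, we need the parameters of the forward-process posterior $q(x_{t-1} \mid x_t, x_0)$. Conditioned on $x_0$, this distribution is Gaussian with mean $\tilde{\mu}_t$ and variance $\tilde{\beta}_t I$:
$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I\big),$$
$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t, \qquad \tilde{\beta}_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t.$$
When sampling with fewer steps than were used in training, $t-1$ stands for the previous timestep in the shortened schedule, and $\alpha_t$ and $\beta_t$ become the effective values $\bar\alpha_t / \bar\alpha_{t-1}$ and $1 - \bar\alpha_t / \bar\alpha_{t-1}$; the methods below compute exactly these quantities.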
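The posterior mean and variance can be sanity-checked numerically using the standard DDPM expressions, which are written out in the comments of the sketch below (the function name `posterior_params` is hypothetical). Averaging the posterior over $x_t \sim q(x_t \mid x_0)$ should reproduce the forward marginal $q(x_{t-1} \mid x_0)$, with mean $\sqrt{\bar\alpha_{t-1}}\,x_0$ and variance $1-\bar\alpha_{t-1}$. The check runs for adjacent timesteps and for timesteps 20 apart, as in a 50-step schedule.

```python
import numpy as np

betas = np.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000) ** 2
alpha_bar = np.cumprod(1.0 - betas)


def posterior_params(ab_t, ab_prev):
    """Return (c_x0, c_xt, beta_tilde) with mu_tilde = c_x0 * x0 + c_xt * x_t.

    c_x0 = sqrt(ab_prev) * beta_t / (1 - ab_t)
    c_xt = sqrt(alpha_t) * (1 - ab_prev) / (1 - ab_t)
    beta_tilde = (1 - ab_prev) / (1 - ab_t) * beta_t
    """
    alpha_t = ab_t / ab_prev          # effective alpha between the two timesteps
    beta_t = 1.0 - alpha_t
    c_x0 = np.sqrt(ab_prev) * beta_t / (1.0 - ab_t)
    c_xt = np.sqrt(alpha_t) * (1.0 - ab_prev) / (1.0 - ab_t)
    beta_tilde = (1.0 - ab_prev) / (1.0 - ab_t) * beta_t
    return c_x0, c_xt, beta_tilde


checks = []
for t, prev in ((500, 499), (500, 480)):
    c_x0, c_xt, beta_tilde = posterior_params(alpha_bar[t], alpha_bar[prev])
    # Substitute x_t = sqrt(ab_t) x0 + sqrt(1 - ab_t) eps and average over eps
    implied_mean_scale = c_x0 + c_xt * np.sqrt(alpha_bar[t])
    implied_var = c_xt ** 2 * (1.0 - alpha_bar[t]) + beta_tilde
    checks.append((implied_mean_scale, np.sqrt(alpha_bar[prev]), implied_var, 1.0 - alpha_bar[prev]))
    print(t, prev, checks[-1])
```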
Variance Calculation
The following method computes the posterior variance $\tilde{\beta}_t$ for a given timestep:
```python
def compute_variance(self, timestep):
    """
    Calculate the posterior variance (beta-tilde) of q(x_{t-1} | x_t, x_0) for the given timestep.

    Args:
        timestep: The current timestep
    Returns:
        The variance value
    """
    # Get the previous timestep in the sampling schedule
    prev_timestep = self.get_previous_timestep(timestep)
    # Get cumulative alphas for current and previous timesteps (1.0 before the first step)
    current_alpha_cumprod = self.alpha_cumprod[timestep]
    prev_alpha_cumprod = self.alpha_cumprod[prev_timestep] if prev_timestep >= 0 else self.unit
    # Effective beta between the two (possibly non-adjacent) timesteps
    current_beta = 1 - current_alpha_cumprod / prev_alpha_cumprod
    # Compute variance using the formula from the DDPM paper
    variance = (1 - prev_alpha_cumprod) / (1 - current_alpha_cumprod) * current_beta
    # Keep the variance strictly positive (it is exactly 0 at the final step)
    variance = torch.clamp(variance, min=1e-20)
    return variance
```
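To see what `compute_variance` returns along a 50-step schedule, the NumPy sketch below reimplements it as a standalone function (`compute_variance_np`, hypothetical) with the same step ratio of 1000 // 50 = 20. At the final timestep the previous cumulative alpha is taken as 1, so the unclamped variance is exactly zero and the clamp lifts it to 1e-20.

```python
import numpy as np

betas = np.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000) ** 2
alpha_bar = np.cumprod(1.0 - betas)
STEP_RATIO = 1000 // 50  # training steps // sampling steps


def compute_variance_np(t):
    """Hypothetical NumPy port of DiffusionSampler.compute_variance for a 50-step schedule."""
    prev_t = t - STEP_RATIO
    ab_t = alpha_bar[t]
    ab_prev = alpha_bar[prev_t] if prev_t >= 0 else 1.0
    beta_t = 1.0 - ab_t / ab_prev          # effective beta over the 20-step stride
    variance = (1.0 - ab_prev) / (1.0 - ab_t) * beta_t
    return max(variance, 1e-20)           # same lower bound as torch.clamp(min=1e-20)


for t in (980, 500, 20, 0):
    print(t, compute_variance_np(t))
```

The posterior variance is always smaller than the effective beta for the same stride, because the factor $(1-\bar\alpha_{t-1})/(1-\bar\alpha_t)$ is below 1.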
Mean Calculation and Update
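The mean is computed in two stages. First, the model's noise prediction $\epsilon_\theta(x_t, t)$ is used to estimate the original sample by inverting the noise addition formula, $\hat{x}_0 = \big(x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)\big)/\sqrt{\bar\alpha_t}$. Second, $\hat{x}_0$ is substituted for $x_0$ in the posterior mean $\tilde{\mu}_t$. On every step except the last, Gaussian noise scaled by the posterior standard deviation $\sqrt{\tilde{\beta}_t}$ is then added to the mean: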
```python
def diffusion_step(self, timestep, latent_data, model_output):
    """
    Perform one reverse (denoising) step, mapping x_t to a sample of x_{t-1}.

    Args:
        timestep: Current timestep t
        latent_data: Current noisy latent x_t
        model_output: Noise predicted by the diffusion model for x_t
    Returns:
        The sample at the previous timestep
    """
    t = timestep
    prev_t = self.get_previous_timestep(t)
    # Cumulative alpha and beta values (alpha-bar is 1.0 before the first step)
    alpha_prod_t = self.alpha_cumprod[t]
    alpha_prod_t_prev = self.alpha_cumprod[prev_t] if prev_t >= 0 else self.unit
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    # Effective alpha and beta between t and prev_t
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t
    # Predict the original sample x_0 by inverting the noise addition formula
    predicted_original_sample = (latent_data - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    # Coefficients of x_0 and x_t in the posterior mean
    original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
    # Compute the posterior mean
    predicted_prev_sample = original_sample_coeff * predicted_original_sample + current_sample_coeff * latent_data
    # Add noise on every step except the last one
    if t > 0:
        noise = torch.randn(model_output.shape, generator=self.random_generator,
                            device=model_output.device, dtype=model_output.dtype)
        # Scale unit noise by the posterior standard deviation
        noise_term = (self.compute_variance(t) ** 0.5) * noise
        predicted_prev_sample = predicted_prev_sample + noise_term
    return predicted_prev_sample


def get_previous_timestep(self, timestep):
    """
    Calculate the previous timestep in the sampling schedule.

    Args:
        timestep: Current timestep
    Returns:
        The previous timestep (negative after the final step)
    """
    # Step back by the ratio between training and sampling steps
    prev_t = timestep - self.total_training_steps // self.num_sampling_steps
    return prev_t
```
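The sketch below strings the timestep schedule, `get_previous_timestep`, and `diffusion_step` together for a single step in NumPy. `reverse_step_np` is a hypothetical port that takes the noise as an argument instead of drawing it from a generator. It feeds in an oracle noise prediction, the exact $\epsilon$ used to create $x_t$, to show that the $\hat{x}_0$ formula then recovers $x_0$ and the mean equals $\tilde{\mu}_t(x_t, x_0)$; a trained model only approximates this $\epsilon$.

```python
import numpy as np

TRAIN_STEPS, SAMPLE_STEPS = 1000, 50
STEP_RATIO = TRAIN_STEPS // SAMPLE_STEPS
betas = np.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, TRAIN_STEPS) ** 2
alpha_bar = np.cumprod(1.0 - betas)

# Same arithmetic as set_sampling_steps: [980, 960, ..., 20, 0]
timesteps = (np.arange(0, SAMPLE_STEPS) * STEP_RATIO)[::-1]


def reverse_step_np(t, x_t, eps_pred, noise):
    """Hypothetical NumPy port of diffusion_step; returns (x0_pred, mean, sample)."""
    prev_t = t - STEP_RATIO                      # get_previous_timestep
    ab_t = alpha_bar[t]
    ab_prev = alpha_bar[prev_t] if prev_t >= 0 else 1.0
    alpha_t = ab_t / ab_prev
    beta_t = 1.0 - alpha_t
    # Invert x_t = sqrt(ab_t) x0 + sqrt(1 - ab_t) eps to estimate x0
    x0_pred = (x_t - np.sqrt(1.0 - ab_t) * eps_pred) / np.sqrt(ab_t)
    coef_x0 = np.sqrt(ab_prev) * beta_t / (1.0 - ab_t)
    coef_xt = np.sqrt(alpha_t) * (1.0 - ab_prev) / (1.0 - ab_t)
    mean = coef_x0 * x0_pred + coef_xt * x_t
    sample = mean
    if t > 0:
        variance = (1.0 - ab_prev) / (1.0 - ab_t) * beta_t
        sample = mean + np.sqrt(variance) * noise
    return x0_pred, mean, sample


rng = np.random.default_rng(0)
x0 = rng.standard_normal(8)
eps = rng.standard_normal(8)
t = int(timesteps[25])                           # 26th step of the 50-step schedule
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
x0_pred, mean, sample = reverse_step_np(t, x_t, eps, rng.standard_normal(8))
print(t, np.abs(x0_pred - x0).max())
```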
Setting Noise Strength
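The strength parameter is used for image-to-image generation: it controls how much noise is added to the input image, and therefore how far the output can depart from it. A strength of 1.0 keeps every sampling step and starts from the noisiest timestep, so little of the input survives; smaller values skip the earliest (noisiest) steps, so denoising starts from a less-noised version of the input and more of its structure is preserved: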
```python
def set_noise_strength(self, strength=1.0):
    """
    Set how strongly the input image is noised before denoising (image-to-image).

    Args:
        strength: A value between 0 and 1; 1.0 runs all sampling steps starting from the
            noisiest timestep, smaller values skip the first (noisiest) steps
    """
    # Number of leading (noisiest) sampling steps to skip
    start_step = self.num_sampling_steps - int(self.num_sampling_steps * strength)
    # Keep only the remaining timesteps; call this once, after set_sampling_steps()
    self.timesteps = self.timesteps[start_step:]
    self.start_step = start_step
```
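The index arithmetic in `set_noise_strength` is easy to check by hand. The NumPy sketch below reproduces it with a hypothetical standalone helper, `remaining_timesteps`, for 50 sampling steps. With strength 0.5 the first 25 timesteps are skipped, so denoising presumably begins from the input noised to timestep 480 (for example with `add_noise`) and runs the remaining 25 steps.

```python
import numpy as np

TRAIN_STEPS, SAMPLE_STEPS = 1000, 50
step_ratio = TRAIN_STEPS // SAMPLE_STEPS
full_schedule = (np.arange(0, SAMPLE_STEPS) * step_ratio)[::-1]   # [980, 960, ..., 0]


def remaining_timesteps(strength):
    """Mirror set_noise_strength: drop the first (noisiest) steps according to strength."""
    start_step = SAMPLE_STEPS - int(SAMPLE_STEPS * strength)
    return start_step, full_schedule[start_step:]


for strength in (1.0, 0.5, 0.2):
    start_step, steps = remaining_timesteps(strength)
    first = steps[0] if len(steps) else None
    print(f"strength={strength}: skip {start_step}, run {len(steps)} steps, start at t={first}")
```

Because the method slices `self.timesteps` in place, calling it a second time would skip steps again; in the class it should be called once, after `set_sampling_steps`.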
Complete Implementation
The complete implementation of the diffusion sampler combines all the above components:
```python
import torch
import numpy as np


class DiffusionSampler:
    """
    A complete implementation of a sampler for Denoising Diffusion Probabilistic Models.

    This class handles the forward process (adding noise to clean data in closed form) and the
    reverse process (removing noise one step at a time using a model's noise prediction),
    including the noise schedule, timestep selection, and posterior variance calculation.
    """

    def __init__(self, random_generator, training_steps=1000, beta_start=0.00085, beta_end=0.0120):
        """
        Initialize the diffusion sampler with parameters for the noise schedule.

        Args:
            random_generator: torch.Generator used for noise sampling (on the same device as the data)
            training_steps: Number of diffusion steps used during training
            beta_start: Starting value of the beta schedule
            beta_end: Ending value of the beta schedule
        """
        # Space beta values linearly in sqrt space, then square them
        self.beta_values = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, training_steps, dtype=torch.float32) ** 2
        # Calculate alpha values as 1 - beta
        self.alpha_values = 1.0 - self.beta_values
        # Compute cumulative product of alphas (alpha-bar)
        self.alpha_cumprod = torch.cumprod(self.alpha_values, dim=0)
        # Store other parameters
        self.unit = torch.tensor(1.0)
        self.random_generator = random_generator
        self.total_training_steps = training_steps
        # Until set_sampling_steps() is called, sample with every training step
        self.num_sampling_steps = training_steps
        # Create timesteps in reverse order for sampling
        self.timesteps = torch.from_numpy(np.arange(0, training_steps)[::-1].copy())

    def set_sampling_steps(self, num_steps=50):
        """
        Set the number of steps to use during the sampling/inference process.

        Args:
            num_steps: Number of steps to use during inference
        """
        self.num_sampling_steps = num_steps
        # Calculate the ratio between training and sampling steps
        step_ratio = self.total_training_steps // self.num_sampling_steps
        # Evenly spaced timesteps in descending order, e.g. [980, 960, ..., 0] for 1000/50
        sampling_timesteps = (np.arange(0, num_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(sampling_timesteps)

    def get_previous_timestep(self, timestep):
        """
        Calculate the previous timestep in the sampling schedule.

        Args:
            timestep: Current timestep
        Returns:
            The previous timestep (negative after the final step)
        """
        # Step back by the ratio between training and sampling steps
        prev_t = timestep - self.total_training_steps // self.num_sampling_steps
        return prev_t

    def compute_variance(self, timestep):
        """
        Calculate the posterior variance (beta-tilde) of q(x_{t-1} | x_t, x_0) for the given timestep.

        Args:
            timestep: The current timestep
        Returns:
            The variance value
        """
        # Get the previous timestep in the sampling schedule
        prev_timestep = self.get_previous_timestep(timestep)
        # Get cumulative alphas for current and previous timesteps (1.0 before the first step)
        current_alpha_cumprod = self.alpha_cumprod[timestep]
        prev_alpha_cumprod = self.alpha_cumprod[prev_timestep] if prev_timestep >= 0 else self.unit
        # Effective beta between the two (possibly non-adjacent) timesteps
        current_beta = 1 - current_alpha_cumprod / prev_alpha_cumprod
        # Compute variance using the formula from the DDPM paper
        variance = (1 - prev_alpha_cumprod) / (1 - current_alpha_cumprod) * current_beta
        # Keep the variance strictly positive (it is exactly 0 at the final step)
        variance = torch.clamp(variance, min=1e-20)
        return variance

    def set_noise_strength(self, strength=1.0):
        """
        Set how strongly the input image is noised before denoising (image-to-image).

        Args:
            strength: A value between 0 and 1; 1.0 runs all sampling steps starting from the
                noisiest timestep, smaller values skip the first (noisiest) steps
        """
        # Number of leading (noisiest) sampling steps to skip
        start_step = self.num_sampling_steps - int(self.num_sampling_steps * strength)
        # Keep only the remaining timesteps; call this once, after set_sampling_steps()
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def add_noise(self, original_data, time_steps):
        """
        Add noise in closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.

        Args:
            original_data: The original data (e.g., images or latents), shape (batch, ...)
            time_steps: Tensor of timesteps, one per batch element (or a single timestep)
        Returns:
            The noisy data
        """
        # Move alpha_cumprod to the same device and dtype as original_data
        alpha_cumprod = self.alpha_cumprod.to(device=original_data.device, dtype=original_data.dtype)
        # Move time_steps to the same device as original_data
        time_steps = time_steps.to(original_data.device)
        # Calculate sqrt of cumulative alphas for the given timesteps
        sqrt_alpha_prod = alpha_cumprod[time_steps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append singleton dimensions so the coefficient broadcasts over original_data
        while len(sqrt_alpha_prod.shape) < len(original_data.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        # Calculate sqrt of (1 - cumulative alphas) for the given timesteps
        sqrt_one_minus_alpha_prod = (1 - alpha_cumprod[time_steps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        # Append singleton dimensions so the coefficient broadcasts over original_data
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_data.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        # Sample random noise
        random_noise = torch.randn(original_data.shape, generator=self.random_generator,
                                   device=original_data.device, dtype=original_data.dtype)
        # Apply the noise addition formula
        noisy_data = sqrt_alpha_prod * original_data + sqrt_one_minus_alpha_prod * random_noise
        return noisy_data

    def diffusion_step(self, timestep, latent_data, model_output):
        """
        Perform one reverse (denoising) step, mapping x_t to a sample of x_{t-1}.

        Args:
            timestep: Current timestep t
            latent_data: Current noisy latent x_t
            model_output: Noise predicted by the diffusion model for x_t
        Returns:
            The sample at the previous timestep
        """
        t = timestep
        prev_t = self.get_previous_timestep(t)
        # Cumulative alpha and beta values (alpha-bar is 1.0 before the first step)
        alpha_prod_t = self.alpha_cumprod[t]
        alpha_prod_t_prev = self.alpha_cumprod[prev_t] if prev_t >= 0 else self.unit
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # Effective alpha and beta between t and prev_t
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t
        # Predict the original sample x_0 by inverting the noise addition formula
        predicted_original_sample = (latent_data - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        # Coefficients of x_0 and x_t in the posterior mean
        original_sample_coeff = (alpha_prod_t_prev ** 0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** 0.5 * beta_prod_t_prev / beta_prod_t
        # Compute the posterior mean
        predicted_prev_sample = original_sample_coeff * predicted_original_sample + current_sample_coeff * latent_data
        # Add noise on every step except the last one
        if t > 0:
            noise = torch.randn(model_output.shape, generator=self.random_generator,
                                device=model_output.device, dtype=model_output.dtype)
            # Scale unit noise by the posterior standard deviation
            noise_term = (self.compute_variance(t) ** 0.5) * noise
            predicted_prev_sample = predicted_prev_sample + noise_term
        return predicted_prev_sample
```
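As an end-to-end check of the sampling loop, the NumPy sketch below mirrors how the class's methods appear intended to be chained: `set_sampling_steps`, then `set_noise_strength`, then `add_noise` at the first remaining timestep, then repeated `diffusion_step` calls. The class itself needs PyTorch and a trained noise-prediction network, so both are replaced by assumptions here: a NumPy port of the arithmetic, and a hypothetical oracle `predict_noise` that knows the clean sample and therefore returns the exact noise. With a perfect noise predictor, the final step (where $\bar\alpha_{t-1}$ is taken as 1) returns $\hat{x}_0$ with no added noise, so the loop should end at the clean sample.

```python
import numpy as np

TRAIN_STEPS, SAMPLE_STEPS, STRENGTH = 1000, 50, 0.8
STEP_RATIO = TRAIN_STEPS // SAMPLE_STEPS
betas = np.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, TRAIN_STEPS) ** 2
alpha_bar = np.cumprod(1.0 - betas)
rng = np.random.default_rng(42)

# set_sampling_steps + set_noise_strength
timesteps = (np.arange(0, SAMPLE_STEPS) * STEP_RATIO)[::-1]
start_step = SAMPLE_STEPS - int(SAMPLE_STEPS * STRENGTH)
timesteps = timesteps[start_step:]

x0 = rng.standard_normal((4, 8, 8))     # stand-in for an encoded input image


def predict_noise(x_t, t):
    """Hypothetical oracle denoiser: returns the exact noise, since it knows x0."""
    return (x_t - np.sqrt(alpha_bar[t]) * x0) / np.sqrt(1.0 - alpha_bar[t])


# add_noise at the first remaining timestep
t0 = timesteps[0]
x = np.sqrt(alpha_bar[t0]) * x0 + np.sqrt(1.0 - alpha_bar[t0]) * rng.standard_normal(x0.shape)

# Repeated diffusion_step calls
for t in timesteps:
    prev_t = t - STEP_RATIO
    ab_t = alpha_bar[t]
    ab_prev = alpha_bar[prev_t] if prev_t >= 0 else 1.0
    alpha_t = ab_t / ab_prev
    beta_t = 1.0 - alpha_t
    x0_pred = (x - np.sqrt(1.0 - ab_t) * predict_noise(x, t)) / np.sqrt(ab_t)
    coef_x0 = np.sqrt(ab_prev) * beta_t / (1.0 - ab_t)
    coef_xt = np.sqrt(alpha_t) * (1.0 - ab_prev) / (1.0 - ab_t)
    mean = coef_x0 * x0_pred + coef_xt * x
    if t > 0:
        variance = max((1.0 - ab_prev) / (1.0 - ab_t) * beta_t, 1e-20)
        mean = mean + np.sqrt(variance) * rng.standard_normal(x.shape)
    x = mean

print(len(timesteps), np.abs(x - x0).max())
```

A real model's noise estimate is imperfect, so its output will differ from the input; the oracle only confirms that the schedule bookkeeping and the final noise-free step are consistent.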