Saving and Loading Models in PyTorch Networks

Synthetic Training Data Generation

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Build a synthetic regression dataset: targets follow y = x^2 plus uniform noise.
train_x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
train_y = train_x.pow(2) + 0.2 * torch.rand(train_x.size())

# Quick look at the input/target scatter before any training happens.
plt.scatter(train_x.numpy(), train_y.numpy())
plt.show()

Persisting a Trained Model

def persist_model():
    """Train a small 1-10-1 MLP on the synthetic data, then save it two ways.

    Writes two checkpoint files:
      - 'full_model.pt':    the entire pickled module object
      - 'model_weights.pt': just the state_dict (learnable parameters)
    Also draws the fitted curve into subplot 1 of a 3-panel figure.
    """
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_fn = torch.nn.MSELoss()

    # Plain full-batch gradient descent; `prediction` from the final
    # iteration is reused below for plotting.
    for _ in range(100):
        prediction = net(train_x)
        loss = loss_fn(prediction, train_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save full model object
    torch.save(net, 'full_model.pt')
    # Save only learnable parameters (recommended for speed)
    torch.save(net.state_dict(), 'model_weights.pt')

    # Visualize fitted curve
    plt.figure(figsize=(10, 3))
    plt.subplot(1, 3, 1)
    plt.title('Trained Model A')
    plt.scatter(train_x.numpy(), train_y.numpy())
    plt.plot(train_x.numpy(), prediction.detach().numpy(), 'r-', linewidth=5)

Storing just the parameter tensors avoids serializing the entire architecture, resulting in faster I/O and smaller files.

Restoring from Checkpoint

def reload_full():
    """Restore the entire serialized module and plot its predictions.

    `weights_only=False` is required on PyTorch >= 2.6, where torch.load
    defaults to weights_only=True and refuses to unpickle arbitrary
    objects such as a full nn.Module. Only load full-model checkpoints
    from trusted sources — unpickling executes arbitrary code.
    """
    model_b = torch.load('full_model.pt', weights_only=False)
    pred_b = model_b(train_x)

    plt.subplot(1, 3, 2)
    plt.title('Reloaded Full Model')
    plt.scatter(train_x.numpy(), train_y.numpy())
    plt.plot(train_x.numpy(), pred_b.detach().numpy(), 'r-', linewidth=5)

def reload_weights():
    """Rebuild the network from scratch and load only the saved parameters.

    The architecture must match the checkpointed model exactly, layer for
    layer, or load_state_dict will fail on mismatched keys/shapes.
    """
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    saved_params = torch.load('model_weights.pt')
    net.load_state_dict(saved_params)
    fitted = net(train_x)

    plt.subplot(1, 3, 3)
    plt.title('Reloaded Weights Only')
    plt.scatter(train_x.numpy(), train_y.numpy())
    plt.plot(train_x.numpy(), fitted.detach().numpy(), 'r-', linewidth=5)

When restoring weights alone, reconstruct the identical network definition before applying load_state_dict.

Execution Flow

# Train and save, then restore via both methods; each call fills one subplot.
persist_model()
reload_full()
reload_weights()
# Bug fix: without an explicit show() after the three subplot calls, the
# comparison figure never appears when this runs as a script.
plt.show()

Running the sequence generates three overlaid plots comparing the original fit and both restoration methods.

Tags: pytorch Model Persistence Deep Learning Neural Networks Training Workflow

Posted on Thu, 07 May 2026 11:30:23 +0000 by ArmanIc