Synthetic Training Data Generation
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
# Toy regression data: 100 evenly spaced inputs on [-1, 1], shaped as a
# (100, 1) column so that each row is one single-feature sample.
train_x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
# Targets follow y = x^2 plus uniform noise scaled into [0, 0.2).
train_y = torch.pow(train_x, 2) + torch.rand(train_x.shape) * 0.2

# Show the raw samples before any model is fitted.
plt.scatter(train_x.numpy(), train_y.numpy())
plt.show()
Persisting a Trained Model
def persist_model():
    """Train a small regression network on the synthetic data and save it.

    Fits a 1-10-1 fully connected network to the module-level ``train_x`` /
    ``train_y`` tensors with 100 full-batch SGD steps, then writes two
    checkpoints into the current working directory:

    * ``full_model.pt`` -- the whole module object (pickled).
    * ``model_weights.pt`` -- only the ``state_dict`` (parameter tensors).

    Finally opens a new figure and draws the fitted curve over the data in
    the first of three subplots. Takes no arguments and returns ``None``.
    """
    model_a = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optim = torch.optim.SGD(model_a.parameters(), lr=0.5)
    criterion = torch.nn.MSELoss()

    for _ in range(100):
        output = model_a(train_x)
        err = criterion(output, train_y)
        optim.zero_grad()  # clear gradients left over from the previous step
        err.backward()
        optim.step()

    # Save full model object
    torch.save(model_a, 'full_model.pt')
    # Save only learnable parameters (the generally recommended format: the
    # checkpoint does not depend on pickling the module classes themselves)
    torch.save(model_a.state_dict(), 'model_weights.pt')

    # Re-evaluate with the final parameters. The last `output` computed in
    # the loop predates the final optimizer step, so plotting it would show
    # a curve that does not correspond to the weights that were just saved.
    with torch.no_grad():
        fitted = model_a(train_x)

    # Visualize fitted curve
    plt.figure(figsize=(10, 3))
    plt.subplot(1, 3, 1)
    plt.title('Trained Model A')
    plt.scatter(train_x.numpy(), train_y.numpy())
    plt.plot(train_x.numpy(), fitted.numpy(), 'r-', linewidth=5)
Storing just the parameter tensors avoids serializing the entire architecture, resulting in faster I/O and smaller files.
Restoring from Checkpoint
def reload_full():
    """Load the fully pickled model and plot its predictions.

    Reads ``full_model.pt`` from the current working directory, evaluates
    the restored module on the module-level ``train_x`` and draws the result
    over the data in the second of three subplots on the current figure.
    Takes no arguments and returns ``None``.
    """
    # A whole pickled nn.Module cannot be read in weights-only mode, which is
    # the default from PyTorch 2.6 onward, so opt out explicitly (the keyword
    # exists since PyTorch 1.13).
    # NOTE: weights_only=False unpickles arbitrary Python objects and can run
    # code embedded in the file -- only load checkpoints you trust.
    model_b = torch.load('full_model.pt', weights_only=False)
    pred_b = model_b(train_x)

    plt.subplot(1, 3, 2)
    plt.title('Reloaded Full Model')
    plt.scatter(train_x.numpy(), train_y.numpy())
    # detach() drops the autograd graph so the tensor can be converted.
    plt.plot(train_x.numpy(), pred_b.detach().numpy(), 'r-', linewidth=5)
def reload_weights():
    """Rebuild the network, restore its saved parameters and plot the fit.

    Constructs a fresh 1-10-1 network, fills it from ``model_weights.pt`` in
    the current working directory, evaluates it on the module-level
    ``train_x`` and draws the result over the data in the third of three
    subplots on the current figure. Takes no arguments and returns ``None``.
    """
    # The layer stack has to mirror the one that produced the state_dict,
    # otherwise the saved keys and shapes will not line up.
    layers = [
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ]
    model_c = torch.nn.Sequential(*layers)

    saved_state = torch.load('model_weights.pt')
    model_c.load_state_dict(saved_state)

    xs = train_x.numpy()
    # Detach from the autograd graph before converting to NumPy.
    curve = model_c(train_x).detach().numpy()

    plt.subplot(1, 3, 3)
    plt.title('Reloaded Weights Only')
    plt.scatter(xs, train_y.numpy())
    plt.plot(xs, curve, 'r-', linewidth=5)
When restoring weights alone, reconstruct the identical network definition before applying load_state_dict.
Execution Flow
# Train and save first so that both checkpoint files exist before reloading.
persist_model()
reload_full()
reload_weights()
# Render the figure holding the three subplots; without this call a
# non-interactive run ends without ever displaying the comparison.
plt.show()
Running the sequence generates three side-by-side plots, each with the fitted curve overlaid on the data, comparing the original fit and both restoration methods.