Detecting manipulated media generated by deepfake algorithms is a pressing challenge. This article presents an end-to-end pipeline for training a binary image classifier that distinguishes real faces from synthetically generated ones using EfficientNet, PyTorch, and the timm library.
Task Overview and Data Format
The objective is to assign a probability of being a deepfake to each face image. The training set contains labeled samples, where label=1 indicates a deepfake and label=0 a genuine face. A validation set is provided for model tuning. Each dataset is described by a text file where every line holds an image filename and its target label:
img_name,target
3381ccbc4df9e7778b720d53a2987014.jpg,1
63fee8a89581307c0b4fd05a48e0ff79.jpg,0
...
Performance is measured primarily by the Area Under the ROC Curve (AUC). In case of ties, the True Positive Rate at a low False Positive Rate (e.g., FPR=1E-3) is used as a secondary metric.
Environment Setup
We begin by importing the necessary libraries and configuring PyTorch for reproducibility and performance:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import timm
import pandas as pd
import numpy as np
from PIL import Image
import time
from tqdm import tqdm
# Set random seed and cuDNN flags
torch.manual_seed(0)
# NOTE(review): deterministic=False with benchmark=True trades exact
# reproducibility for speed — cuDNN autotunes kernels per input shape, so
# repeated runs may differ slightly even with the fixed seed above.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
Utility Classes for Metric Tracking
We define a metric tracker that computes and stores the current value and running average, and a progress meter to display batch-wise information during training and validation.
class MetricTracker:
    """Tracks the most recent value and running average of a scalar metric.

    Attributes:
        val: last value passed to ``update``.
        avg: weighted running average of all updates since ``reset``.
        sum: weighted sum of updates.
        count: total sample count across updates.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation ``val`` averaged over ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Render "name current (average)" using the configured format spec.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
class ProgressDisplay:
    """Prints one tab-separated progress line per call: a '[batch/total]'
    counter followed by the string form of each attached tracker."""

    def __init__(self, num_batches, *trackers):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.trackers = trackers
        self.prefix = ""

    def show(self, batch):
        """Print the prefix, the padded batch counter, and every tracker."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(trk) for trk in self.trackers]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total count so the
        # column stays aligned. (The original computed ``num_batches // 1``,
        # a no-op integer division, before str() — removed.)
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
Custom Dataset Loader
The DeepfakeDataset class reads images from a list of paths, applies optional transformations, and returns the image tensor along with the corresponding label.
class DeepfakeDataset(Dataset):
    """Face-image dataset backed by parallel lists of file paths and labels.

    Each item is a ``(image_tensor, float_label)`` pair; the optional
    ``transform`` is applied to the PIL image before it is returned.
    """

    def __init__(self, img_paths, labels, transform=None):
        self.img_paths = img_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        # Force RGB so grayscale/RGBA files still yield 3-channel tensors.
        sample = Image.open(self.img_paths[index]).convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)
        target = torch.tensor(self.labels[index]).float()
        return sample, target
Training, Validation, and Inference Routines
We implement functions for a single training epoch, evaluating on the validation set, and performing inference on test data. The training loop executes a forward pass, calculates loss and accuracy, and updates model parameters via backpropagation.
def train_one_epoch(dataloader, model, criterion, optimizer, epoch):
    """Run one full training pass over ``dataloader``, updating ``model``
    in place and printing batch time / loss / accuracy every 100 batches."""
    batch_time = MetricTracker('Time', ':6.3f')
    losses = MetricTracker('Loss', ':.4e')
    accuracy = MetricTracker('Acc@1', ':6.2f')
    progress = ProgressDisplay(len(dataloader), batch_time, losses, accuracy)

    model.train()
    tic = time.time()
    for step, (images, targets) in enumerate(dataloader):
        images = images.cuda(non_blocking=True)
        # CrossEntropyLoss expects integer class indices.
        targets = targets.cuda(non_blocking=True).long()

        logits = model(images)
        loss = criterion(logits, targets)

        n = images.size(0)
        losses.update(loss.item(), n)
        batch_acc = (logits.argmax(dim=1) == targets).float().mean() * 100
        accuracy.update(batch_acc.item(), n)

        # Standard backprop step: clear grads, backpropagate, apply update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tic)
        tic = time.time()
        if step % 100 == 0:
            progress.show(step)
def evaluate(dataloader, model, criterion):
    """Compute loss and top-1 accuracy over the validation loader.

    Returns the accuracy MetricTracker so callers can read ``.avg``.
    """
    batch_time = MetricTracker('Time', ':6.3f')
    losses = MetricTracker('Loss', ':.4e')
    accuracy = MetricTracker('Acc@1', ':6.2f')
    progress = ProgressDisplay(len(dataloader), batch_time, losses, accuracy)

    model.eval()
    with torch.no_grad():
        tic = time.time()
        for step, (images, targets) in enumerate(tqdm(dataloader, total=len(dataloader))):
            images = images.cuda()
            targets = targets.cuda().long()

            outputs = model(images)
            loss = criterion(outputs, targets)

            n = images.size(0)
            batch_acc = (outputs.argmax(dim=1) == targets).float().mean() * 100
            losses.update(loss.item(), n)
            accuracy.update(batch_acc.item(), n)

            batch_time.update(time.time() - tic)
            tic = time.time()
    print(' * Val Acc@1 {top1.avg:.3f}'.format(top1=accuracy))
    return accuracy
def infer(dataloader, model, tta=1):
    """Return softmax class probabilities for every sample in ``dataloader``,
    averaged over ``tta`` passes (test-time augmentation), as an
    (N, num_classes) numpy array."""
    model.eval()
    accumulated = None
    for _ in range(tta):
        batch_probs = []
        with torch.no_grad():
            for images, _ in tqdm(dataloader, total=len(dataloader)):
                output = model(images.cuda())
                batch_probs.append(F.softmax(output, dim=1).cpu().numpy())
        pass_probs = np.vstack(batch_probs)
        # Sum per-pass probabilities; divide by tta at the end.
        accumulated = pass_probs if accumulated is None else accumulated + pass_probs
    return accumulated / tta
Model Architecture
We load a pretrained EfficientNet-B1 via timm and adapt the classifier head for our binary task. The model is moved to the GPU.
# EfficientNet-B1 backbone with ImageNet-pretrained weights; timm swaps the
# classifier head for a fresh 2-way linear layer (real vs. deepfake).
model = timm.create_model('efficientnet_b1', pretrained=True, num_classes=2)
model = model.cuda()
Data Preparation and Augmentation
Labels are read from CSV-like files, and paths are constructed accordingly. The training loader applies strong data augmentation (RandAugment), while validation and test sets only undergo resizing and normalization.
# Label files are comma-separated with an 'img_name,target' header
# (see the sample lines earlier in this article).
train_df = pd.read_csv('train_label.txt')
val_df = pd.read_csv('val_label.txt')
base_train_path = 'trainset/'
base_val_path = 'valset/'
# Prepend the split directory to each filename to get a loadable path.
train_df['full_path'] = base_train_path + train_df['img_name']
val_df['full_path'] = base_val_path + val_df['img_name']
batch_size_value = 32
# Training pipeline: resize, RandAugment (2 ops at magnitude 9), then
# ImageNet mean/std normalization to match the pretrained backbone.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Validation pipeline: deterministic resize + normalization only.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = DeepfakeDataset(train_df['full_path'].tolist(), train_df['target'].tolist(), transform=train_transform)
val_dataset = DeepfakeDataset(val_df['full_path'].tolist(), val_df['target'].tolist(), transform=val_transform)
# Shuffle only the training loader; pin_memory speeds up host->GPU copies.
train_loader = DataLoader(train_dataset, batch_size=batch_size_value, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size_value, shuffle=False, num_workers=4, pin_memory=True)
Training Configuration and Execution
The model is trained using the cross-entropy loss and the Adam optimizer. A step-based learning rate scheduler reduces the learning rate periodically. We save checkpoints whenever validation accuracy improves.
# Cross-entropy over the 2 output logits; Adam with step-decay LR schedule.
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
num_epochs = 10
best_acc = 0.0
for epoch in range(num_epochs):
    print('Epoch: ', epoch)
    train_one_epoch(train_loader, model, criterion, optimizer, epoch)
    val_acc = evaluate(val_loader, model, criterion)
    # BUG FIX: MetricTracker.avg is a plain Python float (it accumulates
    # acc.item() values), so the original ``val_acc.avg.item()`` raised
    # AttributeError on the first epoch. Use the float directly.
    current_acc = val_acc.avg
    # BUG FIX: step the scheduler AFTER the epoch's optimizer steps, not
    # before — PyTorch >= 1.1 requires this order; stepping first silently
    # skipped the initial learning rate.
    scheduler.step()
    # Checkpoint whenever validation accuracy improves.
    if current_acc > best_acc:
        best_acc = round(current_acc, 2)
        torch.save(model.state_dict(), f'model_{best_acc}.pt')
Inference on Test Data
Once the best model is selected, we run inference on the test split (here demonstrated on the validation set) and save the predicted probabilities as a submission file.
# Inference loader (demonstrated on the validation split, as noted above).
test_loader = DataLoader(val_dataset, batch_size=batch_size_value, shuffle=False, num_workers=4, pin_memory=True)
# Column 1 of the softmax output corresponds to label=1 (deepfake).
# NOTE(review): this runs the in-memory model from the LAST epoch, not the
# best checkpoint saved during training — reload that state_dict to use it.
probs = infer(test_loader, model, tta=1)[:, 1] # probability of being a deepfake
result_df = val_df[['img_name']].copy()
result_df['y_pred'] = probs
result_df.to_csv('submit.csv', index=False)