PyTorch / Torchvision Learnings

As part of our machine learning user group workshops, I learned a few things about PyTorch and torchvision.

This post describes some of them.

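To crop images as part of the transform step, I wanted to use the functional transforms in torchvision.transforms.functional. Because functional.crop needs the crop box on every call, I wrapped it in a small callable class and used that inside a transforms.Compose: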

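import torchvision

class CustomCropTransform:
    """Keep the top 75 rows and the left 133 columns of an image."""
    def __call__(self, img):
        # functional.crop(img, top, left, height, width)
        return torchvision.transforms.functional.crop(img, top=0, left=0, height=75, width=133)

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(100),  # shorter side -> 100 px, aspect ratio kept
    CustomCropTransform(),               # keep only the top 75 rows
    torchvision.transforms.ToTensor()    # PIL image -> float tensor in [0, 1]
])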

# example of usage
train_dataset = torchvision.datasets.ImageFolder("train", transform=transforms)

Here each image is first resized so that its shorter side is 100 pixels (100x133, height x width, for a 4:3 landscape image), and then CustomCropTransform removes the bottom 25 rows, leaving a 75x133 image.
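To make the sizes concrete, here is a small pure-Python sketch of the arithmetic. The helpers `resized_size` and `trimmed_by_crop` are hypothetical and not part of torchvision; `resized_size` assumes torchvision's rule for an integer `size` in Resize (the shorter edge becomes `size`, the longer edge is scaled by the same factor and truncated to an int). The 400x300 and 1920x1080 inputs are assumed examples.

```python
from typing import Tuple

def resized_size(width: int, height: int, size: int) -> Tuple[int, int]:
    """Return (height, width) after an int-size Resize (assumed rule, see above)."""
    short, long = (width, height) if width <= height else (height, width)
    new_short, new_long = size, int(size * long / short)
    if width <= height:
        return new_long, new_short  # portrait: the width is the short edge
    return new_short, new_long      # landscape: the height is the short edge

def trimmed_by_crop(resized: Tuple[int, int], crop_h: int, crop_w: int) -> Tuple[int, int]:
    """Rows cut from the bottom and columns cut from the right by a crop at top=0, left=0."""
    h, w = resized
    # clamp at 0: a crop larger than the image removes nothing
    return max(h - crop_h, 0), max(w - crop_w, 0)

# Assumed example: a 4:3 landscape photo, 400 wide and 300 high
after_resize = resized_size(width=400, height=300, size=100)
print(after_resize)                            # (100, 133)
print(trimmed_by_crop(after_resize, 75, 133))  # (25, 0): only the bottom 25 rows go

# A wider 16:9 image would also lose columns on the right
print(trimmed_by_crop(resized_size(1920, 1080, 100), 75, 133))  # (25, 44)
```

So the fixed 133-pixel width in the crop only matches the resized width when the input images are 4:3.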

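The second learning was about WeightedRandomSampler, which is useful when the classes in your training dataset do not contain the same number of items. Weighting each sample by the inverse of its class size makes every class equally likely to be drawn.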

from torch.utils.data.sampler import WeightedRandomSampler
import numpy as np
import torch

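# class index of every training item; ImageFolder keeps these in .targets,
# so no image has to be loaded (iterating over train_dataset would load
# and transform every image just to read its label)
train_targets = train_dataset.targets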

# use bincount from numpy to count the number of items in each class
counts = np.bincount(train_targets)

# weight for each class: the inverse of its size (a 1-D array, one entry per class)
weight = 1. / counts

# now weight all the training items
train_samples_weight = torch.tensor([weight[t] for t in train_targets])

# draw len(train_targets) samples per epoch; replacement=True allows
# a sample to be drawn more than once
train_sampler = WeightedRandomSampler(train_samples_weight, len(train_targets), replacement=True)

# now use the train_sampler in the dataloader
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, sampler=train_sampler)
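Why the inverse class counts balance the classes: with replacement=True the sampler picks each index with probability proportional to its weight, so a class's share is its item count times one over its item count, divided by the total weight, which is the same for every class. The stdlib-only sketch below checks that on an assumed toy target list (90 items of class 0, 10 of class 1). The helpers `sample_weights` and `class_shares` are hypothetical, and `random.choices` is only a stand-in for the PyTorch sampler, not the sampler itself.

```python
import random
from collections import Counter
from typing import Dict, List

def sample_weights(targets: List[int]) -> List[float]:
    """Inverse class frequency per sample, like weight[t] = 1 / counts[t] above."""
    counts = Counter(targets)
    return [1.0 / counts[t] for t in targets]

def class_shares(targets: List[int], weights: List[float]) -> Dict[int, float]:
    """Probability that one weighted draw lands in each class."""
    total = sum(weights)
    shares: Dict[int, float] = {}
    for t, w in zip(targets, weights):
        shares[t] = shares.get(t, 0.0) + w / total
    return shares

# Assumed toy dataset: 90 items of class 0, 10 items of class 1
targets = [0] * 90 + [1] * 10

uniform = class_shares(targets, [1.0] * len(targets))
weighted = class_shares(targets, sample_weights(targets))
print(uniform)   # class 0 about 0.9, class 1 about 0.1
print(weighted)  # both classes about 0.5

# Draw indices with replacement, proportional to weight (stand-in for the sampler)
rng = random.Random(0)
drawn = rng.choices(range(len(targets)), weights=sample_weights(targets), k=10_000)
print(Counter(targets[i] for i in drawn))  # roughly equal counts per class
```

The minority class items are therefore repeated many times per epoch, which is why replacement=True is needed.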

The workshop series consists of problem-driven walkthroughs: https://github.com/mlugs/machine-learning-workshop