Easy Self-Supervised Learning with BYOL

BYOL is a surprisingly simple method to leverage unlabeled image data and improve your deep learning models for computer vision.

Frank Odom
Published in The DL
7 min read · Nov 5, 2020

Photo by Djamal Akhmad Fahmi on Unsplash

Note: All code from this article is available in this Google Colab notebook. You can use Colab’s free GPU to run/modify the experiments yourself.

Self-Supervised Learning

Too often in deep learning, there just isn’t enough labeled data. Manually labeling data is too time-intensive, and outsourcing the labor can be prohibitively expensive for small companies or individuals. Self-supervised learning is a nascent sub-field of deep learning that aims to alleviate your data problems by learning from unlabeled samples. The goal is simple: train a model so that similar samples have similar representations. Accomplishing that is (usually) not so simple, but years of research from leaders like Google have greatly advanced this field.

Before BYOL, most attempts at self-supervised learning could be categorized as either contrastive or generative methods. Generative methods (GANs, for example) model the complete data distribution, which can be extremely computationally expensive. Contrastive methods are much less expensive. As described by the authors of BYOL:

Contrastive approaches avoid a costly generation step in pixel space by bringing representation of different views of the same image closer (‘positive pairs’), and spreading representations of views from different images (‘negative pairs’) apart.

For this to work well, though, we must compare each sample against many negative samples. This is problematic because it can introduce instabilities into training and reinforce systematic biases in the dataset. The BYOL authors describe this very clearly:

Contrastive methods are sensitive to the choice of image augmentations. For instance, SimCLR does not work well when removing color distortion from its image augmentations. As an explanation, SimCLR shows that crops of the same image mostly share their color histograms. At the same time, color histograms vary across images. Therefore, when a contrastive task only relies on random crops as image augmentations, it can be mostly solved by focusing on color histograms alone. As a result the representation is not incentivized to retain information beyond color histograms.

This also occurs for other types of data transformations, not just color distortions. In general, contrastive training will be sensitive to systematic biases in your data. Data bias is a widespread issue in machine learning (see, for example, facial recognition performance for women and minorities), and it’s a very serious problem for contrastive methods. Luckily, BYOL does not depend on negative sampling, which provides an escape from this problem.

BYOL: Bootstrap Your Own Latent

The goal of BYOL is similar to contrastive learning, but with one big difference. BYOL does not worry about whether dissimilar samples have dissimilar representations (the contrastive part of contrastive learning). We only care that similar samples have similar representations. This may seem like a subtle difference, but it has big implications for training efficiency and generalization:

  1. Training is more efficient, because BYOL does not require negative sampling. We only sample each training example once per epoch. The negative counterparts can be ignored altogether.
  2. Our model is less sensitive to systematic biases in the training dataset. Usually, this means that it generalizes better to unseen examples.

BYOL minimizes the distance between the representations of each sample and a transformation of that sample. Examples of transformations include translation, rotation, blurring, color inversion, color jitter, Gaussian noise, etc. (I’m using images as a concrete example here, but BYOL works with other data types, too.) We usually train using several different types of transformations, which can be applied together or independently. In general, if you want your model to be invariant under a particular transformation, then that transformation should be included in your training.

Coding BYOL from Scratch

Let’s start by coding the transformations. The BYOL authors use a particular set of transformations, which are similar to those used in SimCLR.

I chose to use Kornia for implementing the transformations — a great Python library with fully differentiable computer vision operations. You could use any other data augmentation/transformation library, or simply write your own. We don’t actually need differentiability for implementing BYOL.

Next, we need an Encoder module. The Encoder is responsible for extracting features from the base model and projecting those features into a lower-dimensional latent space. We’ll implement it using a wrapper class, which allows us to easily use BYOL with any model, not just one that we hard-code into our scripts. There are two primary components:

  1. Feature Extractor: collects the outputs from one of the last model layers.
  2. Projector: a linear layer, which projects those outputs down to a lower dimension.

The feature extraction is implemented using hooks. (If you’re not familiar with them, see my previous article How to Use PyTorch Hooks for an overview and tutorial.) Other than that, the wrapper is pretty straightforward.
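Here is a minimal sketch of such a wrapper, assuming the caller knows the name of the layer to hook and the flattened size of that layer's output. The class and argument names are hypothetical. The projector here is a small MLP (the Linear, BatchNorm, ReLU, Linear shape used in the BYOL paper) with an arbitrary hidden size of 1024, rather than a single linear layer.

```python
import torch
from torch import nn, Tensor


def mlp(in_dim: int, out_dim: int, hidden_dim: int = 1024) -> nn.Module:
    # Linear -> BatchNorm -> ReLU -> Linear head (hidden size is an assumption).
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),
    )


class EncoderWrapper(nn.Module):
    """Captures one intermediate layer of `model` via a forward hook, then projects it."""

    def __init__(self, model: nn.Module, layer: str, feature_dim: int, projection_dim: int = 256):
        super().__init__()
        self.model = model
        self.projector = mlp(feature_dim, projection_dim)
        self._features = None
        # Look up the sub-module by name (e.g. "avgpool" for torchvision ResNets).
        hooked_layer = dict(model.named_modules())[layer]
        hooked_layer.register_forward_hook(self._save_features)

    def _save_features(self, module: nn.Module, inputs, output: Tensor) -> None:
        # PyTorch calls this right after `hooked_layer` runs; stash its output.
        self._features = output

    def forward(self, x: Tensor) -> Tensor:
        self.model(x)  # full forward pass; the hook captures the intermediate features
        features, self._features = self._features, None  # don't hold on to the tensor
        return self.projector(torch.flatten(features, start_dim=1))
```

For a torchvision ResNet18, EncoderWrapper(resnet18(), layer="avgpool", feature_dim=512) would hook the global-average-pooling layer, whose output is a 512-dimensional feature vector per image.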

BYOL contains two Encoder networks with identical architectures. The first (the “online” network) is trained as usual, and its weights are updated with each training batch. The second (referred to as the “target” network) receives no gradients; instead, it is updated using a running average of the online Encoder’s weights. During training, each batch is augmented twice with randomly sampled transformations: the online Encoder is given one view, and the target network is given the other. Each network generates a low-dimensional, latent representation for its view. Then, we attempt to predict the output of the target network from the output of the online network using a multi-layer perceptron (the “predictor”). BYOL maximizes the similarity between this prediction and the target network’s output, and the loss is symmetrized by swapping the two views.

Source: Bootstrap Your Own Latent, Figure 2
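The running average for the target network is usually an exponential moving average: after each optimizer step, the target weights ξ move toward the online weights θ via ξ ← τ·ξ + (1 − τ)·θ. Below is a minimal sketch with hypothetical helper names and a fixed decay τ = 0.99 (the BYOL paper instead starts from τ = 0.996 and raises it toward 1 over training). Copying BatchNorm buffers directly is one simple choice among several.

```python
import copy

import torch
from torch import nn


def make_target(online: nn.Module) -> nn.Module:
    # Same architecture and starting weights as the online encoder, but frozen.
    target = copy.deepcopy(online)
    for p in target.parameters():
        p.requires_grad_(False)
    return target


@torch.no_grad()
def update_target(online: nn.Module, target: nn.Module, tau: float = 0.99) -> None:
    # Exponential moving average: target <- tau * target + (1 - tau) * online.
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1.0 - tau)
    # Buffers (e.g. BatchNorm running statistics) are simply copied over.
    for b_online, b_target in zip(online.buffers(), target.buffers()):
        b_target.copy_(b_online)
```

A larger τ makes the target change more slowly, which gives the online network a more stable regression target.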

Why include the multi-layer perceptron? If we want similar representations before and after transforming the data, shouldn’t we just compare the latent vectors directly? Actually, no. In that case, the network could trivially minimize the loss by mapping every image to the same representation (for example, by driving its weights toward zero), a failure mode known as collapse. Our model would have learned nothing at all. Instead, the predictor MLP learns to predict the target latent vector from the online one, and the target network only changes slowly through the running average. In the BYOL experiments, this combination of a predictor and a slowly moving target keeps the representations from collapsing, so we can continue learning self-consistent representations for our data!
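In the BYOL paper, the online network’s prediction q_θ(z_θ) and the target network’s projection z′_ξ (computed from the other view) are both ℓ2-normalized before taking a squared error, which works out to a cosine-similarity loss:

```latex
\mathcal{L}_{\theta,\xi}
  = \left\lVert \frac{q_\theta(z_\theta)}{\lVert q_\theta(z_\theta)\rVert_2}
      - \frac{z'_\xi}{\lVert z'_\xi\rVert_2} \right\rVert_2^2
  = 2 - 2\,\frac{\langle q_\theta(z_\theta),\, z'_\xi \rangle}
               {\lVert q_\theta(z_\theta)\rVert_2 \,\lVert z'_\xi\rVert_2}
```

The loss ranges from 0 (prediction and target point in the same direction) to 4 (opposite directions). Only the online parameters θ receive gradients; the target parameters ξ follow the moving average, and the full objective adds the same term with the two views swapped.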

At the end of training, we discard the target network altogether. This leaves a single Encoder, which has been trained to generate self-consistent representations for all samples in the training data. This is exactly why BYOL works for self-supervised learning! Because the learned representations are self-consistent, they are (mostly) invariant under different transformations of the data. Similar examples have similar representations in the trained model!

Now, we need to write the BYOL training code. I chose to use PyTorch Lightning for this. It’s a fantastic library for deep learning projects/research written in PyTorch, which includes conveniences like multi-GPU training, experiment logging, model checkpointing, and mixed-precision training. (You can now even run PyTorch models on cloud TPUs with Lightning!)

Most of this is boilerplate code for interfacing with PyTorch Lightning. The important part happens in training_step, where all of the data transformations, feature projections, and similarity losses are computed.
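For readers not using Lightning, here is a framework-free sketch of what a BYOL training step computes. It assumes the online encoder (base model plus projector), the predictor MLP, the target encoder, an augmentation function, and an optimizer over the online and predictor parameters are all passed in. The names are hypothetical, and the fixed decay tau is a simplification.

```python
from typing import Callable

import torch
import torch.nn.functional as F
from torch import nn, Tensor


def normalized_mse(pred: Tensor, target: Tensor) -> Tensor:
    # ||p/|p| - t/|t|||^2 == 2 - 2 * cos(p, t), computed per sample.
    return 2 - 2 * F.cosine_similarity(pred, target, dim=-1)


def byol_step(
    x: Tensor,
    online: nn.Module,      # base model + projector, trained by gradient descent
    predictor: nn.Module,   # MLP applied only on the online branch
    target: nn.Module,      # moving-average copy of `online`, no gradients
    augment: Callable[[Tensor], Tensor],
    optimizer: torch.optim.Optimizer,
    tau: float = 0.99,
) -> float:
    # Two independently augmented views of the same unlabeled batch.
    with torch.no_grad():
        v1, v2 = augment(x), augment(x)

    # Online branch: project, then predict the target's projection.
    p1, p2 = predictor(online(v1)), predictor(online(v2))
    # Target branch: no gradients flow here (stop-gradient).
    with torch.no_grad():
        z1, z2 = target(v1), target(v2)

    # Symmetrized loss: each view's prediction targets the *other* view's projection.
    loss = (normalized_mse(p1, z2) + normalized_mse(p2, z1)).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Nudge the target weights toward the updated online weights (EMA).
    with torch.no_grad():
        for p_o, p_t in zip(online.parameters(), target.parameters()):
            p_t.mul_(tau).add_(p_o, alpha=1 - tau)
    return loss.item()
```

In a Lightning module, roughly the same computation would sit in training_step, with Lightning handling the backward pass and optimizer step, and the moving-average update typically placed in a batch-end hook.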

Practical Example

Time to see BYOL in action. As a practical example, we’ll be using the STL10 dataset. It’s perfect for unsupervised and self-supervised learning experiments because it contains a large number of unlabeled images, as well as labeled training and test sets. As described on the STL10 site:

The STL-10 dataset is an image recognition dataset for developing unsupervised feature learning, deep learning, self-taught learning algorithms. It is inspired by the CIFAR-10 dataset but with some modifications. In particular, each class has fewer labeled training examples than in CIFAR-10, but a very large set of unlabeled examples is provided to learn image models prior to supervised training. The primary challenge is to make use of the unlabeled data (which comes from a similar but different distribution from the labeled data) to build a useful prior.

Torchvision has a built-in STL10 dataset class that downloads and loads the data for us, so we don’t need to worry about downloading or pre-processing it by hand.

As a baseline, we first perform supervised training and then measure the accuracy of our trained model. We can write another (much simpler) Lightning module to accomplish this.

Now, training with PyTorch Lightning is pretty straightforward. Just create DataLoader objects for the training and test sets, and specify the model we want to train. I chose to train for 25 epochs with a learning rate of 1e-4.
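As a framework-agnostic illustration of the same baseline, here is a sketch of a plain-PyTorch training and evaluation loop. The function names are hypothetical, cross-entropy loss and the Adam optimizer are assumptions, and the epoch-count and learning-rate defaults mirror the values above. Device placement (moving tensors to the GPU) is omitted for brevity.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader


def train_supervised(model: nn.Module, loader: DataLoader, epochs: int = 25, lr: float = 1e-4) -> None:
    # Standard cross-entropy training; Adam is an assumed optimizer choice.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for images, labels in loader:
            loss = F.cross_entropy(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


@torch.no_grad()
def accuracy(model: nn.Module, loader: DataLoader) -> float:
    # Fraction of samples whose highest-scoring class matches the label.
    model.eval()
    correct = total = 0
    for images, labels in loader:
        correct += (model(images).argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    return correct / total
```

To adapt a pre-trained torchvision ResNet18 to STL10’s 10 classes, you would replace its final layer first, e.g. model.fc = nn.Linear(model.fc.in_features, 10).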

We achieve around 85% accuracy, which is not bad for a relatively small model like ResNet18. But naturally, we’re not satisfied with 85% accuracy, because we can do better!

For the next experiment, we’ll pre-train the ResNet18 model using BYOL. I chose to train for 50 epochs, using a learning rate of 1e-4 again. (This is by far the most computationally intensive step: it takes roughly 45 minutes in a standard Colab notebook with a K80 GPU.)

Then, we extract the newly trained ResNet18 model, and run supervised training again. (To ensure that forward hooks from BYOL are removed, we instantiate a new model and copy the trained state dictionary over to it.)
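That state-dict copy can be sketched as follows. The helper name is hypothetical, and make_model is assumed to rebuild exactly the same architecture as the trained base model (for an unmodified torchvision ResNet18, that could simply be resnet18). Pass in the base model itself rather than the BYOL wrapper, so the projector and predictor weights aren’t included.

```python
from typing import Callable

import torch
from torch import nn


def strip_hooks(trained: nn.Module, make_model: Callable[[], nn.Module]) -> nn.Module:
    # A freshly constructed model has no forward hooks registered; loading the
    # state dict transfers only the learned parameters and buffers.
    fresh = make_model()
    fresh.load_state_dict(trained.state_dict())
    return fresh
```

Alternatively, keeping the handle returned by register_forward_hook and calling its remove() method also detaches a hook; rebuilding the model simply makes it impossible to miss one.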

Just like that, we’ve boosted the model accuracy by roughly 2.5 percentage points, up to 87.7% overall! It required a decent amount of code (~300 lines) and some helpful libraries, but this was simpler than many other self-supervised methods. (For comparison, take a look at the official SimCLR or SwAV repositories.) And the entire experiment takes less than an hour, even on the modest hardware that Colab provides for free.

Conclusions

There are some very interesting takeaways here. First (and most obviously), BYOL is a pretty cool self-supervised method, which can improve your model’s performance by leveraging unlabeled data. What’s even more interesting is that BYOL outperformed the pre-trained ResNet18, even though all torchvision ResNet models are pre-trained on ImageNet and STL10 is a small subset of ImageNet, with all images downsized from 224x224 to a resolution of 96x96. Because of the change in resolution, we need self-supervised learning to recover some of the model performance. The small labeled training set provided in STL10 is just not enough to accomplish that alone.

ML practitioners often rely too heavily on pre-trained weights in models like ResNet. They’re certainly useful, but they aren’t necessarily well-suited to other data domains, even for data very similar to ImageNet, such as STL10. For that reason, I hope the next few years lead to broad adoption of self-supervised methods in deep learning workflows.

References

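Bootstrap Your Own Latent (BYOL) paper: https://arxiv.org/pdf/2006.07733.pdf
Big Self-Supervised Models are Strong Semi-Supervised Learners (SimCLRv2): https://arxiv.org/pdf/2006.10029v2.pdf
fkodom/byol: https://github.com/fkodom/byol
lucidrains/byol-pytorch: https://github.com/lucidrains/byol-pytorch
google-research/simclr: https://github.com/google-research/simclr
ImageNet: http://image-net.org/
STL-10 dataset: https://cs.stanford.edu/~acoates/stl10/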
