Mixed Barlow Twins
Guarding Barlow Twins Against Overfitting with Mixed Samples
Wele Gedara Chaminda Bandara (Johns Hopkins University), Celso M. De Melo (U.S. Army Research Laboratory), and Vishal M. Patel (Johns Hopkins University)
1 Overview of Mixed Barlow Twins
TL;DR
- Mixed Barlow Twins aims to improve sample interaction during Barlow Twins training via linearly interpolated samples.
- We introduce an additional regularization term to the original Barlow Twins objective, assuming linear interpolation in the input space translates to linearly interpolated features in the feature space.
- Pre-training with this regularization effectively mitigates feature overfitting and further enhances downstream performance on the CIFAR-10, CIFAR-100, TinyImageNet, STL-10, and ImageNet datasets.
Concretely, a mixed batch is formed by linearly interpolating the two augmented batches, $x^M = \lambda x^A + (1-\lambda)\,\mathtt{Shuffle}^*(x^B)$, where $\mathtt{Shuffle}^*$ permutes the samples along the batch dimension. The cross-correlation matrices between the mixed embeddings $Z^M$ and the clean embeddings $Z^A$, $Z^B$, together with their ground truths under the linear-interpolation assumption, are:

$C^{MA} = (Z^M)^TZ^A$

$C^{MB} = (Z^M)^TZ^B$

$C^{MA}_{gt} = \lambda (Z^A)^TZ^A + (1-\lambda)\mathtt{Shuffle}^*(Z^B)^TZ^A$

$C^{MB}_{gt} = \lambda (Z^A)^TZ^B + (1-\lambda)\mathtt{Shuffle}^*(Z^B)^TZ^B$
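The regularization term penalizes the distance between the measured matrices and their ground truths. Below is a minimal PyTorch sketch of that computation; the function name, the batch normalization details, and the squared-error reduction are our own illustrative choices, not the repository's exact code.

```python
import torch

def mixed_bt_regularizer(z_a, z_b, z_m, lam, perm):
    # z_a, z_b: (N, d) projector outputs for the two augmented views
    # z_m:      (N, d) projector output for the mixed batch,
    #           where x_m = lam * x_a + (1 - lam) * x_b[perm]
    # perm:     batch permutation used for Shuffle*
    N = z_a.size(0)

    # batch-normalize each embedding, as in the standard Barlow Twins loss
    za = (z_a - z_a.mean(0)) / z_a.std(0)
    zb = (z_b - z_b.mean(0)) / z_b.std(0)
    zm = (z_m - z_m.mean(0)) / z_m.std(0)

    # cross-correlation between mixed and clean embeddings
    c_ma = zm.T @ za / N
    c_mb = zm.T @ zb / N

    # ground truths under the linear-interpolation assumption
    zb_s = zb[perm]                                   # Shuffle*(Z^B)
    c_ma_gt = (lam * za.T @ za + (1 - lam) * zb_s.T @ za) / N
    c_mb_gt = (lam * za.T @ zb + (1 - lam) * zb_s.T @ zb) / N

    # penalize deviation of the measured matrices from their targets
    return (c_ma - c_ma_gt).pow(2).sum() + (c_mb - c_mb_gt).pow(2).sum()
```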
2 Usage
2.1 Requirements
Before using this repository, make sure you have conda and a recent version of PyTorch installed. You can install PyTorch with the following command (on Linux):
conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia
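To verify the installation, you can confirm that PyTorch can see the GPU:

```python
import torch
print(torch.__version__)           # installed PyTorch version
print(torch.cuda.is_available())   # True if the CUDA build found a GPU
```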
2.2 Installation
To get started, clone this repository:
git clone https://github.com/wgcban/mix-bt.git
Next, create the conda environment named ssl-aug
by executing the following command:
conda env create -f environment.yml
All train-val-test statistics will be automatically uploaded to wandb; please refer to the wandb-quick-start documentation if you are not familiar with using wandb.
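If you have not used wandb on your machine before, authenticate once with your API key:

wandb login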
2.3 Supported Pre-training Datasets
This repository supports the following pre-training datasets:
- CIFAR-10: https://www.cs.toronto.edu/~kriz/cifar.html
- CIFAR-100: https://www.cs.toronto.edu/~kriz/cifar.html
- Tiny-ImageNet: https://github.com/rmccorm4/Tiny-Imagenet-200
- STL-10: https://cs.stanford.edu/~acoates/stl10/
- ImageNet: https://www.image-net.org
The CIFAR-10, CIFAR-100, and STL-10 datasets are directly available in PyTorch. To use TinyImageNet, please follow the preprocessing instructions provided in the TinyImageNet-Script. Download these datasets and place them in the data directory.
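The data directory will then typically look like the following (the folder names shown are the defaults produced by torchvision and the TinyImageNet script; adjust if yours differ):

data/
├── cifar-10-batches-py/
├── cifar-100-python/
├── stl10_binary/
└── tiny-imagenet-200/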
2.4 Supported Transfer Learning Datasets
You can download and place transfer learning datasets under their respective paths, such as data/DTD. The supported transfer learning datasets include:
- DTD: https://www.robots.ox.ac.uk/~vgg/data/dtd/
- MNIST: http://yann.lecun.com/exdb/mnist/
- FashionMNIST: https://github.com/zalandoresearch/fashion-mnist
- CUBirds: http://www.vision.caltech.edu/visipedia/CUB-200-2011.html
- VGGFlower: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
- Traffic Signs: https://benchmark.ini.rub.de/gtsdb_dataset.html
- Aircraft: https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/
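For the datasets that torchvision ships loaders for, the download can also be scripted; a small sketch, assuming the root path matches the layout above:

```python
from torchvision.datasets import MNIST, FashionMNIST

# downloads into data/MNIST and data/FashionMNIST respectively
MNIST(root='data', download=True)
FashionMNIST(root='data', download=True)
```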
2.5 Supported SSL Methods
This repository supports the following Self-Supervised Learning (SSL) methods:
- SimCLR: contrastive learning for SSL
- BYOL: distillation for SSL
- Whitening MSE: infomax for SSL
- Barlow Twins: infomax for SSL
- Mixed Barlow Twins (ours): infomax + mixed samples for SSL
2.6 Pre-Training with Mixed Barlow Twins
To start pre-training and obtain k-NN evaluation results for Mixed Barlow Twins on CIFAR-10, CIFAR-100, TinyImageNet, and STL-10 with ResNet-18/50 backbones, please run:
sh scripts-pretrain-resnet18/[dataset].sh
sh scripts-pretrain-resnet50/[dataset].sh
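Replace [dataset] with the dataset name used by the scripts; for example, assuming the CIFAR-10 script is named cifar10.sh:

sh scripts-pretrain-resnet18/cifar10.sh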
To start the pre-training on ImageNet with a ResNet-50 backbone, please run:
sh scripts-pretrain-resnet50/imagenet.sh
2.7 Linear Evaluation of Pre-trained Models
Before running linear evaluation, ensure that you specify the model_path argument correctly in the corresponding .sh file.
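For example, pointing to one of the checkpoints from Section 3 (the exact variable name inside the script may differ):

model_path=checkpoints/4wdhbpcf_0.0078125_1024_256_cifar10_model.pth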
To obtain linear evaluation results on CIFAR-10, CIFAR-100, TinyImageNet, and STL-10 with ResNet-18/50 backbones, please run:
sh scripts-linear-resnet18/[dataset].sh
sh scripts-linear-resnet50/[dataset].sh
To obtain linear evaluation results on ImageNet with a ResNet-50 backbone, please run:
sh scripts-linear-resnet50/imagenet_sup.sh
2.8 Transfer Learning of Pre-trained Models
To perform transfer learning from models pre-trained on CIFAR-10, CIFAR-100, and STL-10 to fine-grained classification datasets, execute the following command, making sure to specify the model_path argument correctly:
sh scripts-transfer-resnet18/[dataset]-to-x.sh
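For example, to transfer a CIFAR-10 pre-trained model (assuming the script follows the cifar10-to-x.sh naming):

sh scripts-transfer-resnet18/cifar10-to-x.sh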
3 Pre-Trained Checkpoints
Download the pre-trained models from the links in the tables below and store them in the checkpoints/ directory. This repository provides pre-trained checkpoints for both ResNet-18 and ResNet-50 architectures.
3.1 ResNet-18
Dataset | $d$ | $\lambda_{BT}$ | $\lambda_{reg}$ | Download Link to Pretrained Model | KNN Acc. | Linear Acc. |
---|---|---|---|---|---|---|
CIFAR-10 | 1024 | 0.0078125 | 4.0 | 4wdhbpcf_0.0078125_1024_256_cifar10_model.pth | 90.52 | 92.58 |
CIFAR-100 | 1024 | 0.0078125 | 4.0 | 76kk7scz_0.0078125_1024_256_cifar100_model.pth | 61.25 | 69.31 |
TinyImageNet | 1024 | 0.0009765 | 4.0 | 02azq6fs_0.0009765_1024_256_tiny_imagenet_model.pth | 38.11 | 51.67 |
STL-10 | 1024 | 0.0078125 | 2.0 | i7det4xq_0.0078125_1024_256_stl10_model.pth | 88.94 | 91.02 |
3.2 ResNet-50
Dataset | $d$ | $\lambda_{BT}$ | $\lambda_{reg}$ | Download Link to Pretrained Model | KNN Acc. | Linear Acc. |
---|---|---|---|---|---|---|
CIFAR-10 | 1024 | 0.0078125 | 4.0 | v3gwgusq_0.0078125_1024_256_cifar10_model.pth | 91.39 | 93.89 |
CIFAR-100 | 1024 | 0.0078125 | 4.0 | z6ngefw7_0.0078125_1024_256_cifar100_model.pth | 64.32 | 72.51 |
TinyImageNet | 1024 | 0.0009765 | 4.0 | kxlkigsv_0.0009765_1024_256_tiny_imagenet_model.pth | 42.21 | 51.84 |
STL-10 | 1024 | 0.0078125 | 2.0 | pbknx38b_0.0078125_1024_256_stl10_model.pth | 87.79 | 91.70 |
3.3 ImageNet

# Epochs | $d$ | $\lambda_{BT}$ | $\lambda_{reg}$ | Download Link to Pretrained Model | Linear Acc. |
---|---|---|---|---|---|
300 | 8192 | 0.0051 | 0.0 (BT) | 3on0l4wl_0.0000_8192_1024_imagenet_resnet50.pth | 71.3 |
300 | 8192 | 0.0051 | 0.0025 | l418b9zw_0.0025_8192_1024_imagenet_resnet50.pth | 70.9 |
300 | 8192 | 0.0051 | 0.1 | 13awtq23_0.1000_8192_1024_imagenet_resnet50.pth | 71.6 |
300 | 8192 | 0.0051 | 1.0 | 3fb1op86_1.0000_8192_1024_imagenet_resnet50.pth | 72.2 |
300 | 8192 | 0.0051 | 3.0 | TBU | TBU |
300 | 8192 | 0.0051 | 5.0 | TBU | TBU |
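Once downloaded, a checkpoint can be sanity-checked with PyTorch. A small sketch, assuming the file stores a state dict or a dict wrapping one (the actual key layout depends on how the repository saves models):

```python
import torch

# filename taken from the ResNet-18 table above
path = 'checkpoints/4wdhbpcf_0.0078125_1024_256_cifar10_model.pth'
ckpt = torch.load(path, map_location='cpu')

# unwrap a possible {'state_dict': ...} container before inspecting
state = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
print(f'{len(state)} tensors; first keys: {list(state)[:3]}')
```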
4 Training Statistics
Here we provide training and validation (linear probing) statistics for Barlow Twins vs. Mixed Barlow Twins with a ResNet-50 backbone on ImageNet; these curves are logged to wandb during training (see Section 2.2).
5 Disclaimer
A large portion of the code is borrowed from Barlow Twins HSIC (for experiments on small datasets: CIFAR-10, CIFAR-100, TinyImageNet, and STL-10) and from the official implementation of Barlow Twins (for experiments on ImageNet), both of which are great resources for academic development.
Also, note that the implementations of the SOTA methods (SimCLR, BYOL, and Whitening-MSE) in ssl-sota are copied from Whitening-MSE.
We would like to thank all of them for making their repositories publicly available for the research community. 🙏
6 Reference
If you find our work useful, please consider citing it. Thanks!
@misc{bandara2023guarding,
title={Guarding Barlow Twins Against Overfitting with Mixed Samples},
author={Wele Gedara Chaminda Bandara and Celso M. De Melo and Vishal M. Patel},
year={2023},
eprint={2312.02151},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
7 License
This code is released under the MIT license; you can find the complete license file here.